Skip to main content

lvqr_codec/
ts.rs

//! Focused MPEG-TS demuxer for SRT and file-based ingest.
//!
//! Parses a byte stream of 188-byte TS packets, extracts PAT and
//! PMT tables to discover elementary stream PIDs and types, and
//! reassembles PES packets across TS packet boundaries. The
//! caller feeds arbitrary byte chunks via [`TsDemuxer::feed`];
//! the demuxer handles sync-byte recovery internally.
//!
//! Scope: PAT, single-program PMT, PES reassembly with PTS/DTS
//! extraction for H.264 (0x1B), HEVC (0x24), and AAC (0x0F, 0x11).
//! Multi-program TS, SCTE-35, DVB descriptors, and PCR recovery
//! are out of scope for this first cut; the SRT ingest path
//! (Tier 2.8) only needs single-program demux from broadcast
//! encoders.
15
16use std::collections::HashMap;
17
/// Size in bytes of one transport stream packet, sync byte included.
const TS_PACKET_SIZE: usize = 188;
/// Value of the first byte of every TS packet; used to find packet
/// boundaries in the incoming byte stream.
const SYNC_BYTE: u8 = 0x47;
/// PID whose packets are handed to the PAT parser.
const PAT_PID: u16 = 0;
21
/// Codec of an elementary stream, decoded from the one-byte
/// `stream_type` field of a PMT entry (ISO/IEC 13818-1).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StreamType {
    /// `stream_type` 0x1B.
    H264,
    /// `stream_type` 0x24.
    H265,
    /// `stream_type` 0x0F or 0x11.
    Aac,
    /// Any other `stream_type`; the raw byte is kept verbatim.
    Unknown(u8),
}

impl StreamType {
    /// Translate a raw PMT `stream_type` byte into a [`StreamType`].
    /// Codes this demuxer does not recognise are preserved inside
    /// [`StreamType::Unknown`].
    fn from_byte(b: u8) -> Self {
        if b == 0x1B {
            Self::H264
        } else if b == 0x24 {
            Self::H265
        } else if b == 0x0F || b == 0x11 {
            Self::Aac
        } else {
            Self::Unknown(b)
        }
    }
}
41
/// One reassembled PES packet yielded by [`TsDemuxer::feed`].
#[derive(Debug, Clone)]
pub struct PesPacket {
    /// PID of the elementary stream the packet was reassembled from.
    pub pid: u16,
    /// Stream type recorded for `pid` in the demuxer's stream map.
    pub stream_type: StreamType,
    /// Presentation timestamp in 90 kHz ticks. `None` when the
    /// PES header does not carry a PTS (uncommon for video/audio).
    pub pts: Option<u64>,
    /// Decode timestamp in 90 kHz ticks. `None` when the PES
    /// header carries no DTS field, which is the usual case when
    /// PTS == DTS (most audio, non-B-frame video).
    pub dts: Option<u64>,
    /// Raw elementary stream bytes that follow the PES header,
    /// copied verbatim (typically Annex B for video). The demuxer
    /// does not strip or rewrite any framing: audio payloads are
    /// passed through exactly as carried, so any ADTS headers are
    /// still present and must be handled by the consumer.
    pub payload: Vec<u8>,
}
57
/// Per-PID reassembly buffer.
#[derive(Debug)]
struct PesBuffer {
    /// Stream type captured when the PES currently being assembled
    /// started; copied into the emitted [`PesPacket`].
    stream_type: StreamType,
    /// Bytes of the PES accumulated so far, PES header included.
    buf: Vec<u8>,
    /// `true` once a payload-unit start has been seen on this PID;
    /// payloads arriving before that are not accumulated.
    started: bool,
}
65
/// MPEG-TS demuxer with sync recovery and PES reassembly.
///
/// All state is kept between [`TsDemuxer::feed`] calls, so input
/// may be split at arbitrary byte offsets.
#[derive(Debug)]
pub struct TsDemuxer {
    /// Leftover bytes from the previous `feed` call that did not
    /// align to a 188-byte boundary.
    remainder: Vec<u8>,
    /// PMT PID discovered from the PAT. `None` until a PAT with a
    /// non-zero program number has been parsed.
    pmt_pid: Option<u16>,
    /// Elementary stream PID -> stream type, populated from PMT.
    /// Only packets on PIDs present here are reassembled.
    streams: HashMap<u16, StreamType>,
    /// Per-PID PES reassembly buffers.
    pes_bufs: HashMap<u16, PesBuffer>,
}
79
80impl Default for TsDemuxer {
81    fn default() -> Self {
82        Self::new()
83    }
84}
85
86impl TsDemuxer {
87    pub fn new() -> Self {
88        Self {
89            remainder: Vec::new(),
90            pmt_pid: None,
91            streams: HashMap::new(),
92            pes_bufs: HashMap::new(),
93        }
94    }
95
96    /// Feed an arbitrary byte slice into the demuxer. Returns
97    /// zero or more fully reassembled PES packets. The demuxer
98    /// handles sync-byte recovery and cross-call buffering
99    /// internally; callers may pass any chunk size.
100    pub fn feed(&mut self, data: &[u8]) -> Vec<PesPacket> {
101        let mut out = Vec::new();
102
103        // Fast path: drain any buffered remainder first by
104        // completing one packet from remainder + new data, then
105        // process aligned packets directly from the input slice
106        // without copying into the remainder buffer. This avoids
107        // O(N^2) drain cost for large inputs.
108        let input = if self.remainder.is_empty() {
109            data
110        } else {
111            self.remainder.extend_from_slice(data);
112            // Process everything from remainder, then clear it and
113            // return an empty slice so the main loop is skipped.
114            self.process_buf(&mut out);
115            &[]
116        };
117
118        // Process aligned packets directly from the input slice.
119        let mut pos = 0;
120        while pos < input.len() {
121            let sync_off = match input[pos..].iter().position(|&b| b == SYNC_BYTE) {
122                Some(p) => p,
123                None => break,
124            };
125            pos += sync_off;
126            if pos + TS_PACKET_SIZE > input.len() {
127                break;
128            }
129            let pkt: &[u8; TS_PACKET_SIZE] = input[pos..pos + TS_PACKET_SIZE].try_into().unwrap();
130            self.process_packet(pkt, &mut out);
131            pos += TS_PACKET_SIZE;
132        }
133
134        // Stash any trailing bytes for the next call.
135        if pos < input.len() {
136            self.remainder.extend_from_slice(&input[pos..]);
137        }
138
139        out
140    }
141
142    /// Drain the remainder buffer, processing complete packets.
143    fn process_buf(&mut self, out: &mut Vec<PesPacket>) {
144        let mut pos = 0;
145        while pos < self.remainder.len() {
146            let sync_off = match self.remainder[pos..].iter().position(|&b| b == SYNC_BYTE) {
147                Some(p) => p,
148                None => {
149                    self.remainder.clear();
150                    return;
151                }
152            };
153            pos += sync_off;
154            if pos + TS_PACKET_SIZE > self.remainder.len() {
155                break;
156            }
157            let pkt: [u8; TS_PACKET_SIZE] = self.remainder[pos..pos + TS_PACKET_SIZE].try_into().unwrap();
158            self.process_packet(&pkt, out);
159            pos += TS_PACKET_SIZE;
160        }
161        // Keep only the unprocessed tail.
162        if pos > 0 {
163            self.remainder.drain(..pos);
164        }
165    }
166
167    fn process_packet(&mut self, pkt: &[u8; TS_PACKET_SIZE], out: &mut Vec<PesPacket>) {
168        let pid = (((pkt[1] & 0x1F) as u16) << 8) | pkt[2] as u16;
169        let pusi = pkt[1] & 0x40 != 0;
170        let afc = (pkt[3] >> 4) & 0x03;
171
172        let payload_offset = match afc {
173            0b01 => 4,
174            0b11 => {
175                let af_len = pkt[4] as usize;
176                5 + af_len
177            }
178            _ => return,
179        };
180        if payload_offset >= TS_PACKET_SIZE {
181            return;
182        }
183        let payload = &pkt[payload_offset..];
184
185        if pid == PAT_PID {
186            self.parse_pat(payload, pusi);
187        } else if Some(pid) == self.pmt_pid {
188            self.parse_pmt(payload, pusi);
189        } else if self.streams.contains_key(&pid) {
190            self.push_pes(pid, payload, pusi, out);
191        }
192    }
193
194    fn parse_pat(&mut self, payload: &[u8], pusi: bool) {
195        let data = if pusi && !payload.is_empty() {
196            let pointer = payload[0] as usize;
197            if 1 + pointer >= payload.len() {
198                return;
199            }
200            &payload[1 + pointer..]
201        } else {
202            payload
203        };
204        // table_id(1) + flags/length(2) + ts_id(2) + version(1) +
205        // section/last(2) = 8 bytes header before the program loop.
206        if data.len() < 12 {
207            return;
208        }
209        let section_length = (((data[1] & 0x0F) as usize) << 8) | data[2] as usize;
210        let table_end = 3 + section_length;
211        if table_end > data.len() || section_length < 9 {
212            return;
213        }
214        // Program loop starts at byte 8, ends 4 bytes before CRC.
215        let loop_end = table_end.saturating_sub(4);
216        let mut i = 8;
217        while i + 4 <= loop_end {
218            let prog_num = ((data[i] as u16) << 8) | data[i + 1] as u16;
219            let map_pid = (((data[i + 2] & 0x1F) as u16) << 8) | data[i + 3] as u16;
220            if prog_num != 0 {
221                self.pmt_pid = Some(map_pid);
222                break;
223            }
224            i += 4;
225        }
226    }
227
228    fn parse_pmt(&mut self, payload: &[u8], pusi: bool) {
229        let data = if pusi && !payload.is_empty() {
230            let pointer = payload[0] as usize;
231            if 1 + pointer >= payload.len() {
232                return;
233            }
234            &payload[1 + pointer..]
235        } else {
236            payload
237        };
238        if data.len() < 16 {
239            return;
240        }
241        let section_length = (((data[1] & 0x0F) as usize) << 8) | data[2] as usize;
242        let table_end = 3 + section_length;
243        if table_end > data.len() || section_length < 13 {
244            return;
245        }
246        let prog_info_len = (((data[10] & 0x0F) as usize) << 8) | data[11] as usize;
247        let mut i = 12 + prog_info_len;
248        let loop_end = table_end.saturating_sub(4);
249        self.streams.clear();
250        while i + 5 <= loop_end {
251            let st = data[i];
252            let es_pid = (((data[i + 1] & 0x1F) as u16) << 8) | data[i + 2] as u16;
253            let es_info_len = (((data[i + 3] & 0x0F) as usize) << 8) | data[i + 4] as usize;
254            self.streams.insert(es_pid, StreamType::from_byte(st));
255            i += 5 + es_info_len;
256        }
257    }
258
259    fn push_pes(&mut self, pid: u16, payload: &[u8], pusi: bool, out: &mut Vec<PesPacket>) {
260        let stream_type = *self.streams.get(&pid).unwrap_or(&StreamType::Unknown(0));
261
262        if pusi {
263            if let Some(buf) = self.pes_bufs.get_mut(&pid) {
264                if buf.started && !buf.buf.is_empty() {
265                    if let Some(pkt) = Self::finish_pes(pid, buf) {
266                        out.push(pkt);
267                    }
268                }
269            }
270            let entry = self.pes_bufs.entry(pid).or_insert_with(|| PesBuffer {
271                stream_type,
272                buf: Vec::with_capacity(64 * 1024),
273                started: false,
274            });
275            entry.buf.clear();
276            entry.buf.extend_from_slice(payload);
277            entry.started = true;
278            entry.stream_type = stream_type;
279        } else if let Some(buf) = self.pes_bufs.get_mut(&pid) {
280            if buf.started {
281                buf.extend(payload);
282            }
283        }
284    }
285
286    fn finish_pes(pid: u16, buf: &mut PesBuffer) -> Option<PesPacket> {
287        let data = &buf.buf;
288        if data.len() < 9 || data[0] != 0 || data[1] != 0 || data[2] != 1 {
289            return None;
290        }
291        let pes_packet_length = ((data[4] as usize) << 8) | data[5] as usize;
292        let header_data_len = data[8] as usize;
293        let es_start = 9 + header_data_len;
294        if es_start > data.len() {
295            return None;
296        }
297        let flags = data[7];
298        let pts_flag = flags & 0x80 != 0;
299        let dts_flag = flags & 0x40 != 0;
300
301        let pts = if pts_flag && header_data_len >= 5 {
302            Some(parse_ts_timestamp(&data[9..14]))
303        } else {
304            None
305        };
306        let dts = if dts_flag && header_data_len >= 10 {
307            Some(parse_ts_timestamp(&data[14..19]))
308        } else {
309            None
310        };
311
312        // When PES_packet_length is non-zero, it specifies the
313        // exact number of bytes after the 6-byte PES header
314        // prefix. Use it to trim trailing TS padding. When zero
315        // (unbounded, common for video), take everything.
316        let es_end = if pes_packet_length > 0 {
317            (6 + pes_packet_length).min(data.len())
318        } else {
319            data.len()
320        };
321        let payload = data[es_start..es_end].to_vec();
322        if payload.is_empty() {
323            return None;
324        }
325
326        Some(PesPacket {
327            pid,
328            stream_type: buf.stream_type,
329            pts,
330            dts,
331            payload,
332        })
333    }
334}
335
336impl PesBuffer {
337    fn extend(&mut self, data: &[u8]) {
338        self.buf.extend_from_slice(data);
339    }
340}
341
/// Decode a 33-bit MPEG-TS timestamp from its 5-byte PTS/DTS
/// encoding. The bit layout, most significant byte first, is:
/// `pppp_aaam bbbb_bbbb bbbb_bbbm cccc_cccc cccc_cccm`
/// where `a` (3 bits), `b` (15 bits) and `c` (15 bits) are the
/// timestamp bits from high to low, `p` is the 4-bit prefix and
/// `m` are marker bits. Prefix and marker bits are ignored.
///
/// `b` must hold at least 5 bytes; shorter input panics.
fn parse_ts_timestamp(b: &[u8]) -> u64 {
    let high = u64::from((b[0] >> 1) & 0x07);
    let mid = (u64::from(b[1]) << 7) | u64::from(b[2] >> 1);
    let low = (u64::from(b[3]) << 7) | u64::from(b[4] >> 1);
    (high << 30) | (mid << 15) | low
}
352
#[cfg(test)]
mod tests {
    use super::*;

    /// Build one payload-only TS packet (continuity counter 0) on
    /// `pid`. At most 184 payload bytes are copied; the rest of the
    /// packet is 0xFF stuffing.
    fn make_ts_packet(pid: u16, pusi: bool, payload: &[u8]) -> [u8; 188] {
        let pusi_bit: u8 = if pusi { 0x40 } else { 0x00 };
        let body = &payload[..payload.len().min(184)];

        let mut packet = [0xFFu8; 188];
        packet[0] = SYNC_BYTE;
        packet[1] = pusi_bit | ((pid >> 8) as u8 & 0x1F);
        packet[2] = pid as u8;
        packet[3] = 0x10; // payload only, CC=0
        packet[4..4 + body.len()].copy_from_slice(body);
        packet
    }

    /// Encode a timestamp into the 5-byte PTS form with marker bits.
    fn encode_pts(pts: u64) -> [u8; 5] {
        [
            0x21 | ((pts >> 29) as u8 & 0x0E),
            (pts >> 22) as u8,
            0x01 | ((pts >> 14) as u8 & 0xFE),
            (pts >> 7) as u8,
            0x01 | ((pts << 1) as u8 & 0xFE),
        ]
    }

    /// PAT payload (pointer field included) announcing program 1
    /// on `pmt_pid`. The CRC is a zero placeholder.
    fn minimal_pat(pmt_pid: u16) -> Vec<u8> {
        let mut section = Vec::with_capacity(17);
        section.push(0x00); // pointer field
        section.push(0x00); // table_id = PAT
        section.extend_from_slice(&[0xB0, 0x0D]); // section_syntax + length = 13
        section.extend_from_slice(&[0x00, 0x01]); // transport_stream_id
        section.push(0xC1); // version=0, current
        section.extend_from_slice(&[0x00, 0x00]); // section 0 of 0
        section.extend_from_slice(&[0x00, 0x01]); // program_number = 1
        section.push(0xE0 | ((pmt_pid >> 8) as u8 & 0x1F));
        section.push(pmt_pid as u8);
        section.extend_from_slice(&[0x00; 4]); // CRC placeholder
        section
    }

    /// PMT payload (pointer field included) with one H.264 entry on
    /// `video_pid` and one AAC entry on `audio_pid`. The CRC is a
    /// zero placeholder.
    fn minimal_pmt(video_pid: u16, audio_pid: u16) -> Vec<u8> {
        let mut section = Vec::with_capacity(27);
        section.push(0x00); // pointer field
        section.push(0x02); // table_id = PMT
        section.extend_from_slice(&[0xB0, 0x17]); // section_syntax + length = 23
        section.extend_from_slice(&[0x00, 0x01]); // program_number = 1
        section.push(0xC1); // version=0, current
        section.extend_from_slice(&[0x00, 0x00]); // section 0 of 0
        section.extend_from_slice(&[0xE1, 0x00]); // PCR_PID = 0x100
        section.extend_from_slice(&[0xF0, 0x00]); // program_info_length = 0
        // One 5-byte entry per stream: type, PID, ES_info_length = 0.
        for &(stream_type, es_pid) in &[(0x1Bu8, video_pid), (0x0Fu8, audio_pid)] {
            section.push(stream_type);
            section.push(0xE0 | ((es_pid >> 8) as u8 & 0x1F));
            section.push(es_pid as u8);
            section.extend_from_slice(&[0xF0, 0x00]);
        }
        section.extend_from_slice(&[0x00; 4]); // CRC placeholder
        section
    }

    /// Video PES carrying only a PTS in its optional header,
    /// followed by `es_payload`.
    fn minimal_pes(pts_90k: u64, es_payload: &[u8]) -> Vec<u8> {
        // PES_packet_length = 3 fixed optional-header bytes +
        // 5 PTS bytes + ES payload.
        let pes_len = (3 + 5 + es_payload.len()) as u16;

        let mut pes = Vec::with_capacity(14 + es_payload.len());
        pes.extend_from_slice(&[0x00, 0x00, 0x01]); // start code
        pes.push(0xE0); // stream_id (video)
        pes.extend_from_slice(&pes_len.to_be_bytes());
        pes.push(0x80); // marker bits
        pes.push(0x80); // PTS flag set, no DTS
        pes.push(0x05); // header_data_length = 5
        pes.extend_from_slice(&encode_pts(pts_90k & 0x1_FFFF_FFFF));
        pes.extend_from_slice(es_payload);
        pes
    }

    #[test]
    fn demux_discovers_streams_and_yields_pes() {
        let video_pid = 0x100;
        let audio_pid = 0x101;
        let pmt_pid = 0x1000;
        let mut demux = TsDemuxer::new();

        // The PAT reveals the PMT PID.
        let pat_packet = make_ts_packet(PAT_PID, true, &minimal_pat(pmt_pid));
        assert!(demux.feed(&pat_packet).is_empty());
        assert_eq!(demux.pmt_pid, Some(pmt_pid));

        // The PMT reveals both elementary streams.
        let pmt_packet = make_ts_packet(pmt_pid, true, &minimal_pmt(video_pid, audio_pid));
        assert!(demux.feed(&pmt_packet).is_empty());
        assert_eq!(demux.streams.len(), 2);
        assert_eq!(demux.streams[&video_pid], StreamType::H264);
        assert_eq!(demux.streams[&audio_pid], StreamType::Aac);

        // A PES is held back until the next PUSI on the same PID.
        let first = make_ts_packet(video_pid, true, &minimal_pes(90_000, b"nalunalunalu"));
        assert!(demux.feed(&first).is_empty());

        // Starting a second PES on that PID flushes the first one.
        let second = make_ts_packet(video_pid, true, &minimal_pes(180_000, b"nalu2"));
        let packets = demux.feed(&second);
        assert_eq!(packets.len(), 1);
        assert_eq!(packets[0].pid, video_pid);
        assert_eq!(packets[0].stream_type, StreamType::H264);
        assert_eq!(packets[0].pts, Some(90_000));
        assert_eq!(packets[0].payload, b"nalunalunalu");
    }

    #[test]
    fn sync_recovery_skips_garbage() {
        let pmt_pid = 0x1000;
        let mut demux = TsDemuxer::new();

        // Four junk bytes in front of an otherwise valid PAT packet.
        let mut stream = vec![0xDE, 0xAD, 0xBE, 0xEF];
        stream.extend_from_slice(&make_ts_packet(PAT_PID, true, &minimal_pat(pmt_pid)));
        demux.feed(&stream);
        assert_eq!(demux.pmt_pid, Some(pmt_pid));
    }

    #[test]
    fn cross_call_buffering_handles_partial_packets() {
        let pmt_pid = 0x1000;
        let mut demux = TsDemuxer::new();
        let packet = make_ts_packet(PAT_PID, true, &minimal_pat(pmt_pid));
        let (head, tail) = packet.split_at(100);

        // The first 100 bytes alone cannot complete a packet.
        demux.feed(head);
        assert_eq!(demux.pmt_pid, None);

        // The remaining bytes complete it.
        demux.feed(tail);
        assert_eq!(demux.pmt_pid, Some(pmt_pid));
    }

    #[test]
    fn parse_ts_timestamp_round_trips() {
        let pts: u64 = 123_456_789;
        let encoded = encode_pts(pts);
        assert_eq!(parse_ts_timestamp(&encoded), pts);
    }
}