Skip to main content

arcly_stream/protocol/srt/
ts.rs

//! A focused MPEG-TS demuxer for SRT ingest.
//!
//! SRT carries an MPEG-TS bytestream in its data packets. This demuxer consumes
//! 188-byte TS packets, follows the PAT → PMT → elementary-PID chain, reassembles
//! PES packets on the selected video and audio PIDs, and emits one [`TsPayload`]
//! per reassembled PES packet (treated as one access unit) with a decoded PTS.
//! Video payloads are in Annex-B form; audio payloads are the raw elementary bytes.
//!
//! It is deliberately small: single program, first video **and** first audio
//! elementary stream. Video: H.264 (`stream_type` 0x1B) and H.265 (0x24); audio:
//! AAC (0x0F ADTS, 0x11 LATM) and MP3 (0x03/0x04). Continuity counters are not
//! checked: a lost TS packet leaves a gap in the in-progress PES, which is still
//! emitted when the next PES starts.
13
14use crate::CodecId;
15use bytes::Bytes;
16
/// Distinguishes the two kinds of access unit a [`TsPayload`] can carry.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TsTrackKind {
    /// The payload is a video access unit (Annex-B framed).
    Video,
    /// The payload is an audio access unit (e.g. ADTS AAC).
    Audio,
}
25
26/// One reassembled access unit demuxed from the TS stream.
27#[derive(Debug, Clone, PartialEq, Eq)]
28pub struct TsPayload {
29    /// Elementary-stream access unit (Annex-B framed for H.264/H.265 video,
30    /// raw elementary bytes for audio).
31    pub data: Bytes,
32    /// Codec identified from the PMT `stream_type`.
33    pub codec: CodecId,
34    /// Whether this is a video or audio access unit.
35    pub kind: TsTrackKind,
36    /// Presentation timestamp in milliseconds (PES PTS / 90).
37    pub pts_ms: i64,
38    /// Whether the access unit holds a keyframe (IDR). Always `false` for audio.
39    pub keyframe: bool,
40}
41
42/// Per-elementary-stream PES reassembly state (one per video / audio PID).
43#[derive(Debug)]
44struct Track {
45    /// Elementary PID this track reassembles.
46    pid: u16,
47    /// Codec from the PMT `stream_type`.
48    codec: CodecId,
49    /// Video vs audio.
50    kind: TsTrackKind,
51    /// PES reassembly buffer.
52    pes: Vec<u8>,
53    /// PTS (90 kHz) of the PES currently being reassembled.
54    pts: i64,
55    /// Whether a PES is open (between two PUSI markers).
56    open: bool,
57}
58
59impl Track {
60    fn new(pid: u16, codec: CodecId, kind: TsTrackKind) -> Self {
61        Self {
62            pid,
63            codec,
64            kind,
65            pes: Vec::new(),
66            pts: 0,
67            open: false,
68        }
69    }
70
71    /// Feed one TS payload for this PID, flushing the prior access unit on PUSI.
72    fn feed(&mut self, payload: &[u8], pusi: bool, out: &mut Vec<TsPayload>) {
73        if pusi {
74            self.flush(out);
75            if let Some((pts, es_offset)) = parse_pes_header(payload) {
76                self.pts = pts;
77                self.open = true;
78                self.pes.extend_from_slice(&payload[es_offset..]);
79            }
80        } else if self.open {
81            self.pes.extend_from_slice(payload);
82        }
83    }
84
85    /// Emit the buffered PES as an access unit, if any.
86    fn flush(&mut self, out: &mut Vec<TsPayload>) {
87        if !self.open || self.pes.is_empty() {
88            self.pes.clear();
89            self.open = false;
90            return;
91        }
92        let es = std::mem::take(&mut self.pes);
93        let keyframe = matches!(self.kind, TsTrackKind::Video) && is_keyframe(&es, self.codec);
94        out.push(TsPayload {
95            data: Bytes::from(es),
96            codec: self.codec,
97            kind: self.kind,
98            pts_ms: self.pts / 90,
99            keyframe,
100        });
101        self.open = false;
102    }
103}
104
/// Size in bytes of one MPEG-TS packet.
const TS_PACKET_LEN: usize = 188;
/// Value of the sync byte that starts every TS packet.
const TS_SYNC: u8 = 0x47;
107
108/// Stateful MPEG-TS demuxer. Feed it TS bytes with [`push`](Self::push).
109#[derive(Debug)]
110pub struct TsDemuxer {
111    /// PID carrying the PMT, learned from the PAT (PID 0).
112    pmt_pid: Option<u16>,
113    /// The first video elementary stream, learned from the PMT.
114    video: Option<Track>,
115    /// The first audio elementary stream, learned from the PMT.
116    audio: Option<Track>,
117    /// Carry for TS bytes that span `push` calls but don't fill a packet.
118    carry: Vec<u8>,
119}
120
121impl Default for TsDemuxer {
122    fn default() -> Self {
123        Self::new()
124    }
125}
126
127impl TsDemuxer {
128    /// A fresh demuxer that has not yet seen a PAT.
129    pub fn new() -> Self {
130        Self {
131            pmt_pid: None,
132            video: None,
133            audio: None,
134            carry: Vec::new(),
135        }
136    }
137
138    /// Push a chunk of the TS bytestream, returning any access units completed.
139    pub fn push(&mut self, bytes: &[u8]) -> Vec<TsPayload> {
140        let mut out = Vec::new();
141        // Prepend any carried partial packet.
142        let mut data = std::mem::take(&mut self.carry);
143        data.extend_from_slice(bytes);
144
145        let mut i = 0;
146        while i + TS_PACKET_LEN <= data.len() {
147            let pkt = &data[i..i + TS_PACKET_LEN];
148            if pkt[0] == TS_SYNC {
149                self.handle_packet(pkt, &mut out);
150                i += TS_PACKET_LEN;
151            } else {
152                // Resync: skip a byte and look for the next sync.
153                i += 1;
154            }
155        }
156        // Carry the remainder (a partial packet) to the next call.
157        self.carry = data[i..].to_vec();
158        out
159    }
160
161    /// Process one 188-byte TS packet.
162    fn handle_packet(&mut self, pkt: &[u8], out: &mut Vec<TsPayload>) {
163        let pusi = pkt[1] & 0x40 != 0;
164        let pid = (((pkt[1] & 0x1F) as u16) << 8) | pkt[2] as u16;
165        let adaptation = (pkt[3] >> 4) & 0x03;
166        let has_payload = adaptation == 1 || adaptation == 3;
167        if !has_payload {
168            return;
169        }
170        // Skip the adaptation field if present.
171        let mut payload_start = 4;
172        if adaptation == 3 {
173            let af_len = pkt[4] as usize;
174            payload_start = 5 + af_len;
175        }
176        if payload_start >= TS_PACKET_LEN {
177            return;
178        }
179        let payload = &pkt[payload_start..];
180
181        if pid == 0 {
182            self.parse_pat(payload, pusi);
183        } else if Some(pid) == self.pmt_pid {
184            self.parse_pmt(payload, pusi);
185        } else if let Some(track) = self
186            .video
187            .as_mut()
188            .filter(|t| t.pid == pid)
189            .or_else(|| self.audio.as_mut().filter(|t| t.pid == pid))
190        {
191            track.feed(payload, pusi, out);
192        }
193    }
194
195    /// Parse the PAT to learn the PMT PID (first program).
196    fn parse_pat(&mut self, payload: &[u8], pusi: bool) {
197        let section = section_body(payload, pusi);
198        let Some(section) = section else { return };
199        // PAT entries start after the 8-byte section header, 4 bytes each.
200        let mut i = 8;
201        while i + 4 <= section.len().saturating_sub(4) {
202            let program = u16::from_be_bytes([section[i], section[i + 1]]);
203            let pid = (((section[i + 2] & 0x1F) as u16) << 8) | section[i + 3] as u16;
204            if program != 0 {
205                self.pmt_pid = Some(pid);
206                return;
207            }
208            i += 4;
209        }
210    }
211
212    /// Parse the PMT to learn the first video and first audio elementary PID.
213    fn parse_pmt(&mut self, payload: &[u8], pusi: bool) {
214        let Some(section) = section_body(payload, pusi) else {
215            return;
216        };
217        if section.len() < 12 {
218            return;
219        }
220        let program_info_len = (((section[10] & 0x0F) as usize) << 8) | section[11] as usize;
221        let mut i = 12 + program_info_len;
222        while i + 5 <= section.len().saturating_sub(4) {
223            let stream_type = section[i];
224            let pid = (((section[i + 1] & 0x1F) as u16) << 8) | section[i + 2] as u16;
225            let es_info_len = (((section[i + 3] & 0x0F) as usize) << 8) | section[i + 4] as usize;
226            match stream_type_to_track(stream_type) {
227                Some((codec, TsTrackKind::Video)) if self.video.is_none() => {
228                    self.video = Some(Track::new(pid, codec, TsTrackKind::Video));
229                }
230                Some((codec, TsTrackKind::Audio)) if self.audio.is_none() => {
231                    self.audio = Some(Track::new(pid, codec, TsTrackKind::Audio));
232                }
233                _ => {}
234            }
235            i += 5 + es_info_len;
236        }
237    }
238}
239
240/// Map a PMT `stream_type` to a codec and track kind, or `None` if unsupported.
241fn stream_type_to_track(stream_type: u8) -> Option<(CodecId, TsTrackKind)> {
242    match stream_type {
243        0x1B => Some((CodecId::H264, TsTrackKind::Video)),
244        0x24 => Some((CodecId::H265, TsTrackKind::Video)),
245        0x0F | 0x11 => Some((CodecId::AAC, TsTrackKind::Audio)), // ADTS / LATM AAC
246        0x03 | 0x04 => Some((CodecId::MP3, TsTrackKind::Audio)),
247        _ => None,
248    }
249}
250
/// Locate the PSI section inside a TS payload.
///
/// In a PUSI packet the first payload byte is the `pointer_field`: the number
/// of bytes to skip before the section begins. Returns `None` when the payload
/// is empty or the pointer runs past its end. Without PUSI the payload is
/// returned unchanged.
fn section_body(payload: &[u8], pusi: bool) -> Option<&[u8]> {
    if !pusi {
        return Some(payload);
    }
    let (&pointer, after_pointer) = payload.split_first()?;
    after_pointer.get(usize::from(pointer)..)
}
261
/// Decode the start of a PES packet.
///
/// Returns `(pts_90khz, es_payload_offset)`, or `None` when `p` is shorter
/// than the 9-byte fixed header or does not begin with the `00 00 01` start
/// code. The PTS is 0 when the header flags carry no PTS or fewer than 14
/// bytes are available.
fn parse_pes_header(p: &[u8]) -> Option<(i64, usize)> {
    // Start code (3) + stream_id (1) + packet length (2) + flags (2) +
    // header_data_length (1).
    const FIXED_HEADER_LEN: usize = 9;

    let well_formed = p.len() >= FIXED_HEADER_LEN && p.starts_with(&[0x00, 0x00, 0x01]);
    if !well_formed {
        return None;
    }

    // A corrupt or hostile header may declare more optional-header bytes than
    // the payload holds; clamping keeps `payload[es_offset..]` in bounds for
    // the caller.
    let declared_end = FIXED_HEADER_LEN + usize::from(p[8]);
    let es_offset = declared_end.min(p.len());

    // The top bit of the flags byte signals that a PTS is present.
    let has_pts = p[7] & 0x80 != 0;
    let pts = match p.get(9..14) {
        Some(raw) if has_pts => {
            // 33-bit PTS laid out as 3 + 15 + 15 bits, each group followed by
            // a marker bit.
            let high = i64::from((raw[0] >> 1) & 0x07);
            let middle = (i64::from(raw[1]) << 7) | i64::from(raw[2] >> 1);
            let low = (i64::from(raw[3]) << 7) | i64::from(raw[4] >> 1);
            (high << 30) | (middle << 15) | low
        }
        _ => 0,
    };
    Some((pts, es_offset))
}
287
288/// Detect a keyframe in an Annex-B elementary access unit.
289fn is_keyframe(es: &[u8], codec: CodecId) -> bool {
290    let mut i = 0;
291    while i + 4 < es.len() {
292        // Match 3- or 4-byte start codes.
293        let sc3 = es[i] == 0 && es[i + 1] == 0 && es[i + 2] == 1;
294        let sc4 = es[i] == 0 && es[i + 1] == 0 && es[i + 2] == 0 && es[i + 3] == 1;
295        if sc3 || sc4 {
296            let nal_off = if sc4 { i + 4 } else { i + 3 };
297            if let Some(&hdr) = es.get(nal_off) {
298                match codec {
299                    CodecId::H264 if hdr & 0x1F == 5 => return true,
300                    CodecId::H265 if (16..=21).contains(&((hdr >> 1) & 0x3F)) => return true,
301                    _ => {}
302                }
303            }
304            i = nal_off;
305        } else {
306            i += 1;
307        }
308    }
309    false
310}
311
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a 188-byte TS packet for `pid` with `pusi` and the given payload.
    ///
    /// The packet is payload-only (no adaptation field) with continuity
    /// counter 0; a payload shorter than 184 bytes is zero-padded and a longer
    /// one is truncated.
    fn ts_packet(pid: u16, pusi: bool, payload: &[u8]) -> Vec<u8> {
        let mut pkt = vec![0u8; TS_PACKET_LEN];
        let pusi_bit: u8 = if pusi { 0x40 } else { 0 };
        pkt[0] = TS_SYNC;
        pkt[1] = pusi_bit | ((pid >> 8) as u8 & 0x1F);
        pkt[2] = (pid & 0xFF) as u8;
        pkt[3] = 0x10; // payload only, cc 0
        let n = payload.len().min(TS_PACKET_LEN - 4);
        pkt[4..4 + n].copy_from_slice(&payload[..n]);
        pkt
    }

    /// A PAT pointing program 1 at PMT PID 0x1000.
    fn pat() -> Vec<u8> {
        let mut sec = vec![0u8]; // pointer_field = 0
        // table_id 0, section header (8 bytes from table_id), then one program.
        sec.extend_from_slice(&[0x00, 0xB0, 0x0D, 0, 0, 0xC1, 0, 0]);
        sec.extend_from_slice(&[0x00, 0x01]); // program_number 1
        sec.extend_from_slice(&[0xE0 | 0x10, 0x00]); // PMT PID 0x1000
        sec.extend_from_slice(&[0, 0, 0, 0]); // CRC placeholder
        ts_packet(0, true, &sec)
    }

    /// A PMT declaring an H.264 video stream on PID 0x0100, optionally followed
    /// by an AAC audio stream on PID 0x0101.
    fn pmt_with(audio: bool) -> Vec<u8> {
        // section_length counts every byte after the length field: 9 header
        // bytes (program_number through program_info_length), 5 per stream
        // entry, and the 4-byte CRC. It must grow with the audio entry.
        let streams: u8 = if audio { 2 } else { 1 };
        let section_length = 9 + 5 * streams + 4;

        let mut sec = vec![0u8]; // pointer_field
        // table_id 2, header to program_info_length.
        sec.extend_from_slice(&[0x02, 0xB0, section_length, 0, 0x01, 0xC1, 0, 0]);
        sec.extend_from_slice(&[0xE1, 0x00]); // PCR PID
        sec.extend_from_slice(&[0xF0, 0x00]); // program_info_length 0
        sec.extend_from_slice(&[0x1B, 0xE1, 0x00, 0xF0, 0x00]); // H.264 on PID 0x100
        if audio {
            sec.extend_from_slice(&[0x0F, 0xE1, 0x01, 0xF0, 0x00]); // AAC on PID 0x101
        }
        sec.extend_from_slice(&[0, 0, 0, 0]); // CRC placeholder
        ts_packet(0x1000, true, &sec)
    }

    /// A video-only PMT.
    fn pmt() -> Vec<u8> {
        pmt_with(false)
    }

    /// A PES packet on `pid` with `stream_id` wrapping `es` with a PTS.
    fn pes_on(pid: u16, stream_id: u8, es: &[u8], pts: i64) -> Vec<u8> {
        let mut p = vec![0x00, 0x00, 0x01, stream_id, 0x00, 0x00, 0x80, 0x80, 0x05];
        // 5-byte PTS with marker bits.
        let pts = pts as u64;
        p.push((0x21 | (((pts >> 30) & 0x07) << 1)) as u8);
        p.push(((pts >> 22) & 0xFF) as u8);
        p.push((0x01 | (((pts >> 15) & 0x7F) << 1)) as u8);
        p.push(((pts >> 7) & 0xFF) as u8);
        p.push((0x01 | ((pts & 0x7F) << 1)) as u8);
        p.extend_from_slice(es);
        ts_packet(pid, true, &p)
    }

    /// A PES packet on the video PID 0x100 (stream_id 0xE0).
    fn video_pes(es: &[u8], pts: i64) -> Vec<u8> {
        pes_on(0x0100, 0xE0, es, pts)
    }

    /// A PES packet on the audio PID 0x101 (stream_id 0xC0).
    fn audio_pes(es: &[u8], pts: i64) -> Vec<u8> {
        pes_on(0x0101, 0xC0, es, pts)
    }

    #[test]
    fn pes_header_decodes_pts() {
        // Reuse the builder's PES bytes (strip the 4-byte TS header).
        let pes = video_pes(&[], 90_000);
        let (pts, _off) = parse_pes_header(&pes[4..]).unwrap();
        assert_eq!(pts, 90_000);
    }

    #[test]
    fn keyframe_detection_h264_idr() {
        let idr = [0, 0, 0, 1, 0x65, 0xAA];
        assert!(is_keyframe(&idr, CodecId::H264));
        let non_idr = [0, 0, 0, 1, 0x41, 0xAA];
        assert!(!is_keyframe(&non_idr, CodecId::H264));
    }

    #[test]
    fn keyframe_detection_h265_irap() {
        // nal_unit_type 19 (IDR_W_RADL) sits in bits 1..=6 of the first header byte.
        let idr = [0, 0, 0, 1, 19 << 1, 0x01, 0xAA];
        assert!(is_keyframe(&idr, CodecId::H265));
        // nal_unit_type 1 (TRAIL_R) is not a random-access picture.
        let trail = [0, 0, 0, 1, 1 << 1, 0x01, 0xAA];
        assert!(!is_keyframe(&trail, CodecId::H265));
    }

    #[test]
    fn full_chain_pat_pmt_pes_emits_access_unit() {
        let mut d = TsDemuxer::new();
        assert!(d.push(&pat()).is_empty());
        assert!(d.push(&pmt()).is_empty());
        assert_eq!(d.video.as_ref().unwrap().pid, 0x0100);
        assert_eq!(d.video.as_ref().unwrap().codec, CodecId::H264);

        // First PES opens; its AU is emitted when the next PES (PUSI) arrives.
        let idr = [0, 0, 0, 1, 0x65, 0x11, 0x22];
        assert!(d.push(&video_pes(&idr, 9000)).is_empty());
        let delta = [0, 0, 0, 1, 0x41, 0x33];
        let out = d.push(&video_pes(&delta, 12000));
        assert_eq!(out.len(), 1);
        assert_eq!(out[0].codec, CodecId::H264);
        assert_eq!(out[0].kind, TsTrackKind::Video);
        assert_eq!(out[0].pts_ms, 100); // 9000 / 90
        assert!(out[0].keyframe);
        // The access unit begins with the IDR NAL (the fixed-size test packet
        // zero-pads past the payload; valid TS uses adaptation-field stuffing).
        assert!(out[0].data.starts_with(&idr));
    }

    #[test]
    fn carries_partial_packet_across_pushes() {
        let mut d = TsDemuxer::new();
        let p = pat();
        // Feed the PAT split mid-packet; the demuxer must carry the remainder.
        assert!(d.push(&p[..100]).is_empty());
        assert!(d.push(&p[100..]).is_empty());
        // Then a PMT in one shot resolves the video PID.
        d.push(&pmt());
        assert_eq!(d.video.as_ref().unwrap().pid, 0x0100);
    }

    #[test]
    fn resyncs_after_leading_garbage() {
        // Bytes before the first sync byte are skipped one at a time.
        let mut stream = vec![0x00, 0x11, 0x22];
        stream.extend_from_slice(&pat());
        stream.extend_from_slice(&pmt());
        let mut d = TsDemuxer::new();
        assert!(d.push(&stream).is_empty());
        assert_eq!(d.video.as_ref().unwrap().pid, 0x0100);
    }

    #[test]
    fn demuxes_audio_track_alongside_video() {
        let mut d = TsDemuxer::new();
        d.push(&pat());
        d.push(&pmt_with(true));
        // Both tracks are now known.
        assert_eq!(d.audio.as_ref().unwrap().pid, 0x0101);
        assert_eq!(d.audio.as_ref().unwrap().codec, CodecId::AAC);

        // Open an audio PES, then a second one to flush the first.
        let adts = [0xFF, 0xF1, 0x4C, 0x80, 0x01, 0x23];
        assert!(d.push(&audio_pes(&adts, 18000)).is_empty());
        let out = d.push(&audio_pes(&[0xFF, 0xF1, 0x00], 19000));
        assert_eq!(out.len(), 1);
        let au = &out[0];
        assert_eq!(au.kind, TsTrackKind::Audio);
        assert_eq!(au.codec, CodecId::AAC);
        assert!(!au.keyframe, "audio access units are never keyframes");
        assert_eq!(au.pts_ms, 200); // 18000 / 90
        assert!(au.data.starts_with(&adts));
    }

    #[test]
    fn pes_header_with_oversized_declared_length_is_clamped() {
        // Regression (found by the ts_demux fuzz target): a PES header declaring
        // header_data_len far past the payload must not panic when sliced.
        let p = [0x00, 0x00, 0x01, 0xE0, 0x00, 0x00, 0x80, 0x00, 0xFF, 0xAA];
        let (_pts, es_offset) = parse_pes_header(&p).unwrap();
        assert_eq!(es_offset, p.len(), "offset clamped to payload length");
        // Slicing at the returned offset is always valid.
        let _ = &p[es_offset..];
    }

    #[test]
    fn demuxer_survives_oversized_pes_header() {
        // The full path the fuzzer hit: a video PES whose declared header length
        // overruns the TS packet.
        let mut d = TsDemuxer::new();
        d.push(&pat());
        d.push(&pmt());
        let mut pes = vec![0x00, 0x00, 0x01, 0xE0, 0x00, 0x00, 0x80, 0x00, 0xFF];
        pes.extend_from_slice(&[0x11, 0x22]); // far less data than declared
        // Must not panic.
        let _ = d.push(&ts_packet(0x0100, true, &pes));
    }

    #[test]
    fn audio_only_stream_type_maps_to_aac() {
        assert_eq!(
            stream_type_to_track(0x0F),
            Some((CodecId::AAC, TsTrackKind::Audio))
        );
        assert_eq!(
            stream_type_to_track(0x03),
            Some((CodecId::MP3, TsTrackKind::Audio))
        );
        assert!(stream_type_to_track(0x99).is_none());
    }
}