1use std::collections::HashMap;
17
/// Size in bytes of one transport-stream packet: a 4-byte header followed by 184 bytes of
/// adaptation field and/or payload.
const TS_PACKET_SIZE: usize = 4 + 184;
/// Value every TS packet must start with.
const SYNC_BYTE: u8 = 0x47;
/// PID that carries the program association table.
const PAT_PID: u16 = 0x0000;
21
/// Elementary-stream kinds recognised from the PMT `stream_type` byte.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StreamType {
    /// `stream_type` 0x1B.
    H264,
    /// `stream_type` 0x24.
    H265,
    /// `stream_type` 0x0F or 0x11.
    Aac,
    /// Any other `stream_type`; carries the raw byte.
    Unknown(u8),
}

impl StreamType {
    /// Maps a PMT `stream_type` byte to a [`StreamType`], keeping unrecognised codes.
    fn from_byte(code: u8) -> Self {
        const AVC: u8 = 0x1B;
        const HEVC: u8 = 0x24;
        // NOTE(review): 0x0F is AAC with ADTS framing and 0x11 is AAC with LATM framing;
        // both map to `Aac`, so consumers cannot tell the two framings apart.
        const AAC_ADTS: u8 = 0x0F;
        const AAC_LATM: u8 = 0x11;

        match code {
            AVC => StreamType::H264,
            HEVC => StreamType::H265,
            AAC_ADTS | AAC_LATM => StreamType::Aac,
            _ => StreamType::Unknown(code),
        }
    }
}
41
42#[derive(Debug, Clone)]
44pub struct PesPacket {
45 pub pid: u16,
46 pub stream_type: StreamType,
47 pub pts: Option<u64>,
50 pub dts: Option<u64>,
53 pub payload: Vec<u8>,
56}
57
58#[derive(Debug)]
60struct PesBuffer {
61 stream_type: StreamType,
62 buf: Vec<u8>,
63 started: bool,
64}
65
66#[derive(Debug)]
68pub struct TsDemuxer {
69 remainder: Vec<u8>,
72 pmt_pid: Option<u16>,
74 streams: HashMap<u16, StreamType>,
76 pes_bufs: HashMap<u16, PesBuffer>,
78}
79
80impl Default for TsDemuxer {
81 fn default() -> Self {
82 Self::new()
83 }
84}
85
86impl TsDemuxer {
87 pub fn new() -> Self {
88 Self {
89 remainder: Vec::new(),
90 pmt_pid: None,
91 streams: HashMap::new(),
92 pes_bufs: HashMap::new(),
93 }
94 }
95
96 pub fn feed(&mut self, data: &[u8]) -> Vec<PesPacket> {
101 let mut out = Vec::new();
102
103 let input = if self.remainder.is_empty() {
109 data
110 } else {
111 self.remainder.extend_from_slice(data);
112 self.process_buf(&mut out);
115 &[]
116 };
117
118 let mut pos = 0;
120 while pos < input.len() {
121 let sync_off = match input[pos..].iter().position(|&b| b == SYNC_BYTE) {
122 Some(p) => p,
123 None => break,
124 };
125 pos += sync_off;
126 if pos + TS_PACKET_SIZE > input.len() {
127 break;
128 }
129 let pkt: &[u8; TS_PACKET_SIZE] = input[pos..pos + TS_PACKET_SIZE].try_into().unwrap();
130 self.process_packet(pkt, &mut out);
131 pos += TS_PACKET_SIZE;
132 }
133
134 if pos < input.len() {
136 self.remainder.extend_from_slice(&input[pos..]);
137 }
138
139 out
140 }
141
142 fn process_buf(&mut self, out: &mut Vec<PesPacket>) {
144 let mut pos = 0;
145 while pos < self.remainder.len() {
146 let sync_off = match self.remainder[pos..].iter().position(|&b| b == SYNC_BYTE) {
147 Some(p) => p,
148 None => {
149 self.remainder.clear();
150 return;
151 }
152 };
153 pos += sync_off;
154 if pos + TS_PACKET_SIZE > self.remainder.len() {
155 break;
156 }
157 let pkt: [u8; TS_PACKET_SIZE] = self.remainder[pos..pos + TS_PACKET_SIZE].try_into().unwrap();
158 self.process_packet(&pkt, out);
159 pos += TS_PACKET_SIZE;
160 }
161 if pos > 0 {
163 self.remainder.drain(..pos);
164 }
165 }
166
167 fn process_packet(&mut self, pkt: &[u8; TS_PACKET_SIZE], out: &mut Vec<PesPacket>) {
168 let pid = (((pkt[1] & 0x1F) as u16) << 8) | pkt[2] as u16;
169 let pusi = pkt[1] & 0x40 != 0;
170 let afc = (pkt[3] >> 4) & 0x03;
171
172 let payload_offset = match afc {
173 0b01 => 4,
174 0b11 => {
175 let af_len = pkt[4] as usize;
176 5 + af_len
177 }
178 _ => return,
179 };
180 if payload_offset >= TS_PACKET_SIZE {
181 return;
182 }
183 let payload = &pkt[payload_offset..];
184
185 if pid == PAT_PID {
186 self.parse_pat(payload, pusi);
187 } else if Some(pid) == self.pmt_pid {
188 self.parse_pmt(payload, pusi);
189 } else if self.streams.contains_key(&pid) {
190 self.push_pes(pid, payload, pusi, out);
191 }
192 }
193
194 fn parse_pat(&mut self, payload: &[u8], pusi: bool) {
195 let data = if pusi && !payload.is_empty() {
196 let pointer = payload[0] as usize;
197 if 1 + pointer >= payload.len() {
198 return;
199 }
200 &payload[1 + pointer..]
201 } else {
202 payload
203 };
204 if data.len() < 12 {
207 return;
208 }
209 let section_length = (((data[1] & 0x0F) as usize) << 8) | data[2] as usize;
210 let table_end = 3 + section_length;
211 if table_end > data.len() || section_length < 9 {
212 return;
213 }
214 let loop_end = table_end.saturating_sub(4);
216 let mut i = 8;
217 while i + 4 <= loop_end {
218 let prog_num = ((data[i] as u16) << 8) | data[i + 1] as u16;
219 let map_pid = (((data[i + 2] & 0x1F) as u16) << 8) | data[i + 3] as u16;
220 if prog_num != 0 {
221 self.pmt_pid = Some(map_pid);
222 break;
223 }
224 i += 4;
225 }
226 }
227
228 fn parse_pmt(&mut self, payload: &[u8], pusi: bool) {
229 let data = if pusi && !payload.is_empty() {
230 let pointer = payload[0] as usize;
231 if 1 + pointer >= payload.len() {
232 return;
233 }
234 &payload[1 + pointer..]
235 } else {
236 payload
237 };
238 if data.len() < 16 {
239 return;
240 }
241 let section_length = (((data[1] & 0x0F) as usize) << 8) | data[2] as usize;
242 let table_end = 3 + section_length;
243 if table_end > data.len() || section_length < 13 {
244 return;
245 }
246 let prog_info_len = (((data[10] & 0x0F) as usize) << 8) | data[11] as usize;
247 let mut i = 12 + prog_info_len;
248 let loop_end = table_end.saturating_sub(4);
249 self.streams.clear();
250 while i + 5 <= loop_end {
251 let st = data[i];
252 let es_pid = (((data[i + 1] & 0x1F) as u16) << 8) | data[i + 2] as u16;
253 let es_info_len = (((data[i + 3] & 0x0F) as usize) << 8) | data[i + 4] as usize;
254 self.streams.insert(es_pid, StreamType::from_byte(st));
255 i += 5 + es_info_len;
256 }
257 }
258
259 fn push_pes(&mut self, pid: u16, payload: &[u8], pusi: bool, out: &mut Vec<PesPacket>) {
260 let stream_type = *self.streams.get(&pid).unwrap_or(&StreamType::Unknown(0));
261
262 if pusi {
263 if let Some(buf) = self.pes_bufs.get_mut(&pid) {
264 if buf.started && !buf.buf.is_empty() {
265 if let Some(pkt) = Self::finish_pes(pid, buf) {
266 out.push(pkt);
267 }
268 }
269 }
270 let entry = self.pes_bufs.entry(pid).or_insert_with(|| PesBuffer {
271 stream_type,
272 buf: Vec::with_capacity(64 * 1024),
273 started: false,
274 });
275 entry.buf.clear();
276 entry.buf.extend_from_slice(payload);
277 entry.started = true;
278 entry.stream_type = stream_type;
279 } else if let Some(buf) = self.pes_bufs.get_mut(&pid) {
280 if buf.started {
281 buf.extend(payload);
282 }
283 }
284 }
285
286 fn finish_pes(pid: u16, buf: &mut PesBuffer) -> Option<PesPacket> {
287 let data = &buf.buf;
288 if data.len() < 9 || data[0] != 0 || data[1] != 0 || data[2] != 1 {
289 return None;
290 }
291 let pes_packet_length = ((data[4] as usize) << 8) | data[5] as usize;
292 let header_data_len = data[8] as usize;
293 let es_start = 9 + header_data_len;
294 if es_start > data.len() {
295 return None;
296 }
297 let flags = data[7];
298 let pts_flag = flags & 0x80 != 0;
299 let dts_flag = flags & 0x40 != 0;
300
301 let pts = if pts_flag && header_data_len >= 5 {
302 Some(parse_ts_timestamp(&data[9..14]))
303 } else {
304 None
305 };
306 let dts = if dts_flag && header_data_len >= 10 {
307 Some(parse_ts_timestamp(&data[14..19]))
308 } else {
309 None
310 };
311
312 let es_end = if pes_packet_length > 0 {
317 (6 + pes_packet_length).min(data.len())
318 } else {
319 data.len()
320 };
321 let payload = data[es_start..es_end].to_vec();
322 if payload.is_empty() {
323 return None;
324 }
325
326 Some(PesPacket {
327 pid,
328 stream_type: buf.stream_type,
329 pts,
330 dts,
331 payload,
332 })
333 }
334}
335
336impl PesBuffer {
337 fn extend(&mut self, data: &[u8]) {
338 self.buf.extend_from_slice(data);
339 }
340}
341
/// Decodes the 33-bit timestamp packed into a 5-byte PES PTS/DTS field.
///
/// The value is split into 3 + 15 + 15 bits, each followed by a marker bit that is
/// discarded here. Only the first 5 bytes of `b` are read; panics if `b` is shorter.
fn parse_ts_timestamp(b: &[u8]) -> u64 {
    let byte = |i: usize| u64::from(b[i]);
    let bits_32_30 = (byte(0) >> 1) & 0x07;
    let bits_29_15 = (byte(1) << 7) | (byte(2) >> 1);
    let bits_14_0 = (byte(3) << 7) | (byte(4) >> 1);
    (bits_32_30 << 30) | (bits_29_15 << 15) | bits_14_0
}
352
#[cfg(test)]
mod tests {
    use super::*;

    /// Packs a 33-bit timestamp into the 5-byte PES PTS layout ('0010' prefix, marker bits set).
    fn encode_pts(pts: u64) -> [u8; 5] {
        [
            0x21 | ((pts >> 29) as u8 & 0x0E),
            (pts >> 22) as u8,
            0x01 | ((pts >> 14) as u8 & 0xFE),
            (pts >> 7) as u8,
            0x01 | ((pts << 1) as u8 & 0xFE),
        ]
    }

    /// Builds a payload-only TS packet; bytes not covered by `payload` stay 0xFF stuffing.
    fn make_ts_packet(pid: u16, pusi: bool, payload: &[u8]) -> [u8; TS_PACKET_SIZE] {
        let mut pkt = [0xFF_u8; TS_PACKET_SIZE];
        let pusi_bit = if pusi { 0x40 } else { 0x00 };
        let [pid_hi, pid_lo] = pid.to_be_bytes();
        // sync, PUSI + PID high bits, PID low bits, adaptation_field_control = payload only
        pkt[..4].copy_from_slice(&[SYNC_BYTE, pusi_bit | (pid_hi & 0x1F), pid_lo, 0x10]);
        let n = payload.len().min(TS_PACKET_SIZE - 4);
        pkt[4..4 + n].copy_from_slice(&payload[..n]);
        pkt
    }

    /// Single-program PAT (program 1 -> `pmt_pid`), prefixed with a zero pointer field.
    fn minimal_pat(pmt_pid: u16) -> Vec<u8> {
        let [hi, lo] = pmt_pid.to_be_bytes();
        let mut section = vec![
            0x00, // pointer_field
            0x00, // table_id: program association section
            0xB0, 0x0D, // section_syntax_indicator = 1, section_length = 13
            0x00, 0x01, // transport_stream_id
            0xC1, // version 0, current_next_indicator = 1
            0x00, 0x00, // section_number, last_section_number
            0x00, 0x01, // program_number = 1
            0xE0 | (hi & 0x1F), lo, // program_map_PID
        ];
        section.extend_from_slice(&[0x00; 4]); // CRC_32 (not checked by the demuxer)
        section
    }

    /// PMT with one H.264 stream and one AAC stream, prefixed with a zero pointer field.
    fn minimal_pmt(video_pid: u16, audio_pid: u16) -> Vec<u8> {
        let es_entry = |stream_type: u8, pid: u16| -> [u8; 5] {
            let [hi, lo] = pid.to_be_bytes();
            // stream_type, elementary_PID, ES_info_length = 0
            [stream_type, 0xE0 | (hi & 0x1F), lo, 0xF0, 0x00]
        };
        let mut section = vec![
            0x00, // pointer_field
            0x02, // table_id: TS program map section
            0xB0, 0x17, // section_syntax_indicator = 1, section_length = 23
            0x00, 0x01, // program_number
            0xC1, // version 0, current_next_indicator = 1
            0x00, 0x00, // section_number, last_section_number
            0xE1, 0x00, // PCR_PID = 0x100
            0xF0, 0x00, // program_info_length = 0
        ];
        section.extend_from_slice(&es_entry(0x1B, video_pid));
        section.extend_from_slice(&es_entry(0x0F, audio_pid));
        section.extend_from_slice(&[0x00; 4]); // CRC_32 (not checked by the demuxer)
        section
    }

    /// Video PES packet carrying only a PTS and `es_payload`.
    fn minimal_pes(pts_90k: u64, es_payload: &[u8]) -> Vec<u8> {
        // PES_packet_length covers the two flag bytes, PES_header_data_length,
        // the 5-byte PTS and the payload.
        let pes_len = (3 + 5 + es_payload.len()) as u16;
        let [len_hi, len_lo] = pes_len.to_be_bytes();
        let mut pes = vec![
            0x00, 0x00, 0x01, // packet_start_code_prefix
            0xE0, // stream_id
            len_hi, len_lo, // PES_packet_length
            0x80, // '10' marker bits, no other flags
            0x80, // PTS_DTS_flags = '10' (PTS only)
            0x05, // PES_header_data_length
        ];
        pes.extend_from_slice(&encode_pts(pts_90k & 0x1_FFFF_FFFF));
        pes.extend_from_slice(es_payload);
        pes
    }

    #[test]
    fn demux_discovers_streams_and_yields_pes() {
        let (video_pid, audio_pid, pmt_pid) = (0x100, 0x101, 0x1000);
        let mut demux = TsDemuxer::new();

        // The PAT announces the PMT PID but yields no PES.
        let emitted = demux.feed(&make_ts_packet(PAT_PID, true, &minimal_pat(pmt_pid)));
        assert!(emitted.is_empty());
        assert_eq!(demux.pmt_pid, Some(pmt_pid));

        // The PMT registers one H.264 and one AAC elementary stream.
        let pmt = make_ts_packet(pmt_pid, true, &minimal_pmt(video_pid, audio_pid));
        assert!(demux.feed(&pmt).is_empty());
        assert_eq!(demux.streams.len(), 2);
        assert_eq!(demux.streams[&video_pid], StreamType::H264);
        assert_eq!(demux.streams[&audio_pid], StreamType::Aac);

        // The first PES is only emitted once the next PUSI on its PID arrives.
        let first = make_ts_packet(video_pid, true, &minimal_pes(90_000, b"nalunalunalu"));
        assert!(demux.feed(&first).is_empty());

        let second = make_ts_packet(video_pid, true, &minimal_pes(180_000, b"nalu2"));
        let packets = demux.feed(&second);
        assert_eq!(packets.len(), 1);
        let pkt = &packets[0];
        assert_eq!(pkt.pid, video_pid);
        assert_eq!(pkt.stream_type, StreamType::H264);
        assert_eq!(pkt.pts, Some(90_000));
        assert_eq!(pkt.payload, b"nalunalunalu");
    }

    #[test]
    fn sync_recovery_skips_garbage() {
        let pmt_pid = 0x1000;
        let mut stream = vec![0xDE, 0xAD, 0xBE, 0xEF];
        stream.extend_from_slice(&make_ts_packet(PAT_PID, true, &minimal_pat(pmt_pid)));

        let mut demux = TsDemuxer::new();
        demux.feed(&stream);
        assert_eq!(demux.pmt_pid, Some(pmt_pid));
    }

    #[test]
    fn cross_call_buffering_handles_partial_packets() {
        let pmt_pid = 0x1000;
        let packet = make_ts_packet(PAT_PID, true, &minimal_pat(pmt_pid));
        let (head, tail) = packet.split_at(100);

        let mut demux = TsDemuxer::new();
        demux.feed(head);
        assert_eq!(demux.pmt_pid, None, "an incomplete packet must not be parsed yet");

        demux.feed(tail);
        assert_eq!(demux.pmt_pid, Some(pmt_pid));
    }

    #[test]
    fn parse_ts_timestamp_round_trips() {
        let pts: u64 = 123_456_789;
        assert_eq!(parse_ts_timestamp(&encode_pts(pts)), pts);
    }
}