use std::{
collections::HashMap,
env,
fs::File,
io::{
Read,
Write,
},
};
use libmpegts::{
mux::{
Multiplexer,
MuxFrame,
MuxService,
MuxStream,
},
psi::{
PAT_PID,
PatSectionRef,
PmtSectionRef,
Psi,
},
slicer::TsSlicer,
ts::PACKET_SIZE,
};
/// Per-PID reassembly state for one elementary stream being remuxed.
struct DemuxStream {
    /// Bytes of the PES packet currently being collected from TS payloads.
    pes_buffer: Vec<u8>,
    /// Taken from the adaptation-field flag bit 0x40 (random_access_indicator)
    /// of the TS packet that started the PES in `pes_buffer`; forwarded as the
    /// resulting frame's key-frame flag.
    is_key_frame: bool,
    /// Index of this stream inside the `Multiplexer`, as returned by
    /// `Multiplexer::stream_index`.
    mux_index: usize,
}
/// Parses the fixed part and the optional header of an MPEG-2 PES packet.
///
/// `data` must begin at the PES start code (`00 00 01`). On success returns
/// `(pts, dts, header_size)`, where `header_size` is the offset of the first
/// elementary-stream byte: 9 fixed bytes plus `PES_header_data_length`.
///
/// Returns `None` when:
/// * `data` is shorter than the fixed header or the declared header,
/// * the start code is missing,
/// * the packet carries no PTS (`PTS_DTS_flags` is `00` or the forbidden `01`),
/// * the declared header is too short to hold the PTS/DTS fields it announces.
fn parse_pes_header(data: &[u8]) -> Option<(u64, Option<u64>, usize)> {
    if data.len() < 9 || !data.starts_with(&[0x00, 0x00, 0x01]) {
        return None;
    }
    let header_size = 9 + usize::from(data[8]);
    // Timestamps must lie inside the declared header. Bounding only by
    // `data.len()` would decode payload bytes as a PTS/DTS whenever
    // PES_header_data_length is smaller than the flags require.
    let header = data.get(.. header_size)?;
    let pts_dts_flags = (data[7] >> 6) & 0x03;
    if pts_dts_flags < 0b10 {
        return None;
    }
    let pts = parse_timestamp(header.get(9 .. 14)?);
    let dts = if pts_dts_flags == 0b11 {
        Some(parse_timestamp(header.get(14 .. 19)?))
    } else {
        None
    };
    Some((pts, dts, header_size))
}
/// Decodes a 33-bit PTS/DTS from its 5-byte PES-header encoding.
///
/// Byte layout (`p` = 4-bit prefix, `1` = marker bit, both ignored):
/// `pppp AAA1 | BBBBBBBB | BBBBBBB1 | CCCCCCCC | CCCCCCC1`
/// where `AAA` are bits 32..30, `B…` bits 29..15 and `C…` bits 14..0.
///
/// Only the first five bytes of `buf` are read.
///
/// # Panics
///
/// Panics if `buf` is shorter than 5 bytes.
fn parse_timestamp(buf: &[u8]) -> u64 {
    let byte = |i: usize| u64::from(buf[i]);
    let top = (byte(0) >> 1) & 0x07;
    let middle = (byte(1) << 7) | ((byte(2) >> 1) & 0x7F);
    let bottom = (byte(3) << 7) | ((byte(4) >> 1) & 0x7F);
    (top << 30) | (middle << 15) | bottom
}
/// Returns `true` for the PMT `stream_type` values this remuxer forwards.
///
/// The set is fixed: 0x02, 0x1B, 0x24, 0x33, 0x03, 0x04, 0x0F, 0x11, 0x06,
/// 0x81, 0x82, 0x83, 0x84 and 0x87. Any other stream type is skipped.
fn is_av_stream(stream_type: u8) -> bool {
    const FORWARDED_TYPES: [u8; 14] = [
        // Video (per ISO/IEC 13818-1: MPEG-2, H.264, HEVC, VVC).
        0x02, 0x1B, 0x24, 0x33,
        // Audio (per ISO/IEC 13818-1: MPEG-1/2 audio, AAC ADTS, AAC LATM).
        0x03, 0x04, 0x0F, 0x11,
        // 0x06 is PES private data; 0x81-0x84/0x87 are user-private values.
        // NOTE(review): presumably meant for AC-3/E-AC-3-family audio — confirm.
        0x06, 0x81, 0x82, 0x83, 0x84, 0x87,
    ];
    FORWARDED_TYPES.contains(&stream_type)
}
/// Locates the first program of the transport stream in `path` and describes it.
///
/// Reads the file from the start and assembles PAT sections, remembering the
/// first entry with a non-zero program number. It then assembles that
/// program's PMT. Scanning stops at the first PMT section that parses successfully.
///
/// Returns two values:
/// * The `transport_stream_id` of the most recently parsed PAT (1 if no PAT was
///   parsed).
/// * A `MuxService` holding the program number, PMT PID, PCR PID and every
///   stream whose type passes `is_av_stream`, with each stream's descriptors
///   concatenated as raw bytes.
///
/// # Errors
///
/// Propagates I/O errors from opening or reading `path`. Returns
/// `ErrorKind::InvalidData` when no PMT is found or it lists no A/V streams.
fn discover_streams(path: &str) -> std::io::Result<(u16, MuxService)> {
    let mut reader = File::open(path)?;
    let mut chunk = [0u8; PACKET_SIZE * 1024];
    let mut slicer = TsSlicer::new();
    let mut pat_assembler = Psi::default();
    let mut pmt_assembler = Psi::default();
    let mut service = MuxService::default();
    let mut tsid = 1u16;
    let mut pmt_found = false;

    'scan: loop {
        let filled = reader.read(&mut chunk)?;
        if filled == 0 {
            break;
        }
        for packet in slicer.slice(&chunk[.. filled]) {
            let pid = packet.pid();

            if pid == PAT_PID {
                let Some(section) = pat_assembler.assemble(&packet) else {
                    continue;
                };
                let Ok(pat) = PatSectionRef::try_from(section) else {
                    continue;
                };
                tsid = pat.transport_stream_id();
                for entry in pat.programs() {
                    let Ok(entry) = entry else {
                        continue;
                    };
                    // program_number 0 points at the network PID, not a PMT.
                    if entry.program_number() == 0 {
                        continue;
                    }
                    service.program_number = entry.program_number();
                    service.pmt_pid = entry.pid();
                    break;
                }
                continue;
            }

            // The PMT PID stays 0 until a PAT has named a program.
            if service.pmt_pid == 0 || pid != service.pmt_pid {
                continue;
            }
            let Some(section) = pmt_assembler.assemble(&packet) else {
                continue;
            };
            let Ok(pmt) = PmtSectionRef::try_from(section) else {
                continue;
            };
            service.pcr_pid = pmt.pcr_pid();
            for entry in pmt.streams() {
                let Ok(entry) = entry else {
                    continue;
                };
                if !is_av_stream(entry.stream_type()) {
                    continue;
                }
                let stream_type = entry.stream_type();
                let elementary_pid = entry.elementary_pid();
                let mut descriptor_bytes = Vec::new();
                if let Some(descriptors) = entry.stream_descriptors() {
                    for descriptor in descriptors {
                        if let Ok(descriptor) = descriptor {
                            descriptor_bytes.extend_from_slice(descriptor.bytes());
                        }
                    }
                }
                service.streams.push(MuxStream {
                    stream_type,
                    elementary_pid,
                    stream_descriptors: descriptor_bytes,
                });
            }
            pmt_found = true;
            break 'scan;
        }
    }

    let problem = if !pmt_found {
        Some("PMT not found in input file")
    } else if service.streams.is_empty() {
        Some("no audio/video streams found in PMT")
    } else {
        None
    };
    match problem {
        Some(message) => Err(std::io::Error::new(std::io::ErrorKind::InvalidData, message)),
        None => Ok((tsid, service)),
    }
}
/// Pushes the PES packet buffered in `demux` to `mux` as one `MuxFrame`, then
/// empties the buffer.
///
/// Behaviour by case:
/// * The buffer is empty: nothing happens.
/// * `parse_pes_header` rejects the buffer (for example no start code or no
///   PTS): the bytes are discarded.
/// * The header is followed by no elementary-stream bytes: the bytes are
///   discarded.
/// * Otherwise: the bytes after the header are sent with the parsed PTS/DTS
///   and `demux.is_key_frame`.
///
/// In every non-empty case the buffer is cleared, keeping its allocation.
fn flush_pes(demux: &mut DemuxStream, mux: &mut Multiplexer) {
    if demux.pes_buffer.is_empty() {
        return;
    }
    if let Some((pts, dts, header_size)) = parse_pes_header(&demux.pes_buffer) {
        // Move the ES bytes out; `pes_buffer` keeps the header and its capacity.
        let payload = demux.pes_buffer.split_off(header_size);
        if !payload.is_empty() {
            let frame = MuxFrame {
                data: payload,
                is_key_frame: demux.is_key_frame,
                pts_dts: Some((pts, dts).into()),
            };
            mux.push_frame(demux.mux_index, frame);
        }
    }
    demux.pes_buffer.clear();
}
fn main() -> std::io::Result<()> {
let args: Vec<String> = env::args().collect();
if args.len() != 3 {
eprintln!("Usage: ts_remux <input.ts> <output.ts>");
std::process::exit(1);
}
let input_path = &args[1];
let output_path = &args[2];
let (tsid, service) = discover_streams(input_path)?;
eprintln!(
"TSID: {}, PNR: {}, PMT PID: {}, PCR PID: {}",
tsid, service.program_number, service.pmt_pid, service.pcr_pid
);
for s in &service.streams {
eprintln!(
" stream type=0x{:02X} pid={}",
s.stream_type, s.elementary_pid
);
}
let mut mux = Multiplexer::new(tsid);
mux.add_service(&service);
let mut demux_map: HashMap<u16, DemuxStream> = HashMap::new();
for stream in &service.streams {
let mux_index = mux.stream_index(stream.elementary_pid).unwrap();
demux_map.insert(
stream.elementary_pid,
DemuxStream {
mux_index,
pes_buffer: Vec::with_capacity(256 * 1024),
is_key_frame: false,
},
);
}
let mut input_file = File::open(input_path)?;
let mut output_file = File::create(output_path)?;
let mut input_buf = [0u8; PACKET_SIZE * 1024];
let mut output_buf = [0u8; PACKET_SIZE * 1024];
let mut slicer = TsSlicer::new();
let mut frame_count: u64 = 0;
loop {
let n = input_file.read(&mut input_buf)?;
if n == 0 {
break;
}
for packet in slicer.slice(&input_buf[.. n]) {
let pid = packet.pid();
let Some(demux) = demux_map.get_mut(&pid) else {
continue;
};
let Some(payload) = packet.payload() else {
continue;
};
if packet.is_payload_start() {
flush_pes(demux, &mut mux);
frame_count += 1;
demux.is_key_frame = packet
.adaptation_field()
.map(|af| af.discontinuity_indicator() || payload.len() > 0)
.unwrap_or(false);
let raw = packet.as_ref();
if raw[3] & 0x20 != 0 && raw[4] > 0 {
demux.is_key_frame = raw[5] & 0x40 != 0;
} else {
demux.is_key_frame = false;
}
}
demux.pes_buffer.extend_from_slice(payload);
}
loop {
let n = mux.drain(&mut output_buf);
if n == 0 {
break;
}
output_file.write_all(&output_buf[.. n])?;
}
}
loop {
let n = mux.drain(&mut output_buf);
if n == 0 {
break;
}
output_file.write_all(&output_buf[.. n])?;
}
eprintln!("Done. {frame_count} frames processed.");
Ok(())
}