use std::collections::VecDeque;
use crate::{
pes::{
EsFrame,
PesHeader,
PesPacketizer,
PtsDts,
STREAM_ID_AUDIO,
STREAM_ID_PRIVATE_1,
STREAM_ID_VIDEO,
Timestamp,
},
psi::{
PAT_PID,
PatBuilder,
PatConfig,
PatProgram,
PmtBuilder,
PmtConfig,
PmtStream,
PsiPacketizer,
},
ts::PACKET_SIZE,
};
/// Offset subtracted from a frame timestamp before it is turned into a PCR
/// value (`700 * 90` timestamp ticks).
// NOTE(review): the `* 90` factor suggests milliseconds expressed in 90 kHz
// ticks — confirm against `Timestamp`.
const PCR_DELAY: Timestamp = Timestamp::new(700 * 90);

/// Timestamp distance (`40 * 90` ticks) after which a new PCR packet is
/// scheduled ahead of the next frame.
const PCR_INTERVAL: u64 = 40 * 90;

/// Timestamp distance (`500 * 90` ticks) after which the PAT/PMT are repeated
/// ahead of the next frame.
const PSI_INTERVAL: u64 = 500 * 90;
/// One elementary-stream frame handed to the multiplexer by the caller.
pub struct MuxFrame {
    /// Raw frame bytes; used as the payload of the frame's PES packet.
    pub data: Vec<u8>,
    /// Marks a key frame. Key frames set the data-alignment flag on the PES
    /// header, are passed on as `rai`, and force PSI and PCR output before the
    /// frame itself.
    pub is_key_frame: bool,
    /// Timestamps of the frame. Frames without timestamps are discarded when
    /// they are pushed into a stream.
    pub pts_dts: Option<PtsDts>,
}
/// Description of one elementary stream belonging to a [`MuxService`].
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct MuxStream {
    /// Stream type code announced in the PMT; also decides which PES stream id
    /// family the stream is assigned to.
    pub stream_type: u8,
    /// PID on which the stream's packets are emitted.
    pub elementary_pid: u16,
    /// Raw descriptor bytes copied into the stream's PMT entry.
    pub stream_descriptors: Vec<u8>,
}
/// Description of the service (program) carried by the multiplexer.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct MuxService {
    /// Program number written into the PAT and PMT.
    pub program_number: u16,
    /// PID on which the PMT is emitted.
    pub pmt_pid: u16,
    /// PID announced as PCR carrier. It has to match the `elementary_pid` of a
    /// registered stream, otherwise `Multiplexer::add_service` panics.
    pub pcr_pid: u16,
    /// Raw program-level descriptor bytes copied into the PMT.
    pub program_descriptors: Vec<u8>,
    /// Raw service descriptor bytes.
    // NOTE(review): not read anywhere in the visible part of this file.
    pub service_descriptors: Vec<u8>,
    /// Elementary streams of the service, in registration order.
    pub streams: Vec<MuxStream>,
}
/// Scheduling metadata of the frame currently loaded into a stream's packetizer.
#[derive(Copy, Clone)]
struct ActiveFrame {
    /// Timestamp used to order frames across streams and to pace PSI/PCR.
    timestamp: Option<Timestamp>,
    /// `true` while the key frame still has to be preceded by PAT/PMT.
    pending_key_psi: bool,
    /// `true` while the key frame still has to be preceded by a PCR packet.
    pending_key_pcr: bool,
}
/// Per-stream multiplexing state: PMT entry, packetizer and frame queue.
struct StreamContext {
    /// Entry describing this stream in the PMT.
    pmt_stream: PmtStream,
    /// PES stream id placed in every PES header of this stream.
    stream_id: u8,
    /// Packetizer bound to the stream's elementary PID.
    packetizer: PesPacketizer,
    /// Frames accepted by `push_frame` that are not yet being packetized.
    pending: VecDeque<MuxFrame>,
    /// Metadata of the frame currently inside `packetizer`, if any.
    active: Option<ActiveFrame>,
}
impl StreamContext {
fn new(stream: &MuxStream, stream_id: u8) -> Self {
let pid = stream.elementary_pid;
Self {
pmt_stream: PmtStream {
stream_type: stream.stream_type,
elementary_pid: stream.elementary_pid,
stream_descriptors: stream.stream_descriptors.clone(),
},
stream_id,
packetizer: PesPacketizer::new(pid),
pending: VecDeque::new(),
active: None,
}
}
fn push_frame(&mut self, frame: MuxFrame) {
if frame.pts_dts.is_none() {
return;
}
self.pending.push_back(frame);
}
fn load_next_frame(&mut self) {
debug_assert!(self.active.is_none());
let Some(frame) = self.pending.pop_front() else {
return;
};
let timestamp = frame.pts_dts.map(|ts| ts.timestamp());
let mut header = PesHeader::new(self.stream_id).with_data_alignment(frame.is_key_frame);
if let Some(pts_dts) = frame.pts_dts {
header = header.with_pts_dts(pts_dts);
}
let es_frame = EsFrame {
header,
payload: frame.data,
rai: frame.is_key_frame,
};
self.packetizer.set_frame(es_frame);
self.active = Some(ActiveFrame {
timestamp,
pending_key_psi: frame.is_key_frame,
pending_key_pcr: frame.is_key_frame && timestamp.is_some(),
});
}
}
/// Multiplexing state of the registered service.
struct ServiceContext {
    /// Program number written into the PAT and PMT.
    program_number: u16,
    /// PID announced in the PAT for the PMT.
    pmt_pid: u16,
    /// PCR PID announced in the PMT.
    pcr_pid: u16,
    /// Raw program-level descriptor bytes for the PMT.
    program_descriptors: Vec<u8>,
    /// Packetizer emitting the PMT sections on `pmt_pid`.
    pmt_packetizer: PsiPacketizer,
    /// Index into `Multiplexer::streams` of the stream that carries the PCR.
    pcr_stream_index: usize,
}
impl ServiceContext {
fn new(service: &MuxService, pcr_stream_index: usize) -> Self {
Self {
program_number: service.program_number,
pmt_pid: service.pmt_pid,
pcr_pid: service.pcr_pid,
program_descriptors: service.program_descriptors.clone(),
pmt_packetizer: PsiPacketizer::new(service.pmt_pid),
pcr_stream_index,
}
}
}
/// What `Multiplexer::drain` emits next.
#[derive(Copy, Clone)]
enum MultiplexerState {
    /// Nothing in flight: load frames and decide what to emit.
    Idle,
    /// PAT packets are being written; followed by `EmitPmt`.
    EmitPat,
    /// PMT packets are being written; followed by `Idle`.
    EmitPmt,
    /// A single PCR packet derived from the given timestamp is due.
    EmitPcr(Timestamp),
    /// The active frame of the stream with the given index is being written.
    EmitFrame(usize),
}
/// Interleaves PSI tables, PCR packets and elementary-stream frames into
/// fixed-size transport packets.
pub struct Multiplexer {
    /// Transport stream id written into the PAT.
    transport_stream_id: u16,
    /// Packetizer emitting the PAT sections on `PAT_PID`.
    pat_packetizer: PsiPacketizer,
    /// Current position of the output state machine.
    state: MultiplexerState,
    /// `true` when PAT/PMT must be rebuilt before their next emission.
    psi_dirty: bool,
    /// The registered service, if any; `add_service` replaces it.
    service: Option<ServiceContext>,
    /// All registered elementary streams, in registration order.
    streams: Vec<StreamContext>,
    /// Frame timestamp at which the last PCR packet was scheduled.
    last_pcr_timestamp: Option<Timestamp>,
    /// Frame timestamp at which the last PAT/PMT emission was scheduled.
    last_psi_timestamp: Option<Timestamp>,
}
impl Multiplexer {
    /// Creates an empty multiplexer for the given transport stream id.
    ///
    /// No service or stream is registered yet, and the PSI tables are marked
    /// as needing a rebuild.
    pub fn new(transport_stream_id: u16) -> Self {
        Self {
            transport_stream_id,
            pat_packetizer: PsiPacketizer::new(PAT_PID),
            state: MultiplexerState::Idle,
            psi_dirty: true,
            service: None,
            streams: Vec::new(),
            last_pcr_timestamp: None,
            last_psi_timestamp: None,
        }
    }

    /// Registers `service` and all of its streams.
    ///
    /// Every stream receives a PES stream id chosen from its `stream_type`,
    /// and the PSI tables are marked for rebuilding. A previously registered
    /// service context is replaced, while streams accumulate.
    ///
    /// # Panics
    ///
    /// Panics if `service.pcr_pid` does not match the elementary PID of any
    /// registered stream.
    pub fn add_service(&mut self, service: &MuxService) {
        for stream in &service.streams {
            let stream_id = match stream.stream_type {
                // Stream types treated as video.
                0x02 | 0x1B | 0x24 | 0x33 => self.next_stream_id(STREAM_ID_VIDEO, 0x0F),
                // Stream types treated as audio.
                0x03 | 0x04 | 0x0F | 0x11 => self.next_stream_id(STREAM_ID_AUDIO, 0x1F),
                // Everything else (including 0x06, 0x81-0x84 and 0x87) is
                // carried as private stream 1.
                _ => STREAM_ID_PRIVATE_1,
            };
            self.streams.push(StreamContext::new(stream, stream_id));
        }

        let pcr_stream_index = self.stream_index(service.pcr_pid).unwrap();
        self.service = Some(ServiceContext::new(service, pcr_stream_index));
        self.psi_dirty = true;
    }

    /// Returns the index of the first registered stream using
    /// `elementary_pid`, or `None` when there is no such stream.
    pub fn stream_index(&self, elementary_pid: u16) -> Option<usize> {
        self.streams
            .iter()
            .enumerate()
            .find_map(|(idx, stream)| {
                (stream.pmt_stream.elementary_pid == elementary_pid).then_some(idx)
            })
    }

    /// Queues `frame` on the stream at `stream_index`.
    ///
    /// An unknown index is ignored, as are frames without timestamps.
    pub fn push_frame(&mut self, stream_index: usize, frame: MuxFrame) {
        let Some(target) = self.streams.get_mut(stream_index) else {
            return;
        };
        target.push_frame(frame);
    }

    /// Writes as many complete packets as possible into `buf` and returns the
    /// number of bytes written (always a multiple of `PACKET_SIZE`).
    ///
    /// Output stops when `buf` has no room for another whole packet or when no
    /// stream has a frame ready. Trailing bytes of `buf` that cannot hold a
    /// whole packet are left untouched.
    pub fn drain(&mut self, buf: &mut [u8]) -> usize {
        let (packets, _) = buf.as_chunks_mut::<PACKET_SIZE>();
        let mut produced = 0;

        while produced < packets.len() {
            let free = &mut packets[produced ..];
            let state = self.state;

            produced += match state {
                MultiplexerState::Idle => {
                    self.load_frames();
                    // PSI takes precedence over PCR, which takes precedence
                    // over the frame payload itself.
                    match self.select_stream() {
                        None => break,
                        Some((idx, meta)) if self.check_psi_state(&meta) => {
                            self.prepare_psi_state(idx, &meta)
                        }
                        Some((idx, meta)) if self.check_pcr_state(&meta) => {
                            self.prepare_pcr_state(idx, &meta)
                        }
                        Some((idx, _)) => self.state = MultiplexerState::EmitFrame(idx),
                    }
                    0
                }
                MultiplexerState::EmitPat => self.emit_pat(free),
                MultiplexerState::EmitPmt => self.emit_pmt(free),
                MultiplexerState::EmitPcr(timestamp) => self.emit_pcr(timestamp, free),
                MultiplexerState::EmitFrame(idx) => self.emit_stream(idx, free),
            };
        }

        produced * PACKET_SIZE
    }

    /// Tells whether PAT/PMT have to be emitted before the frame described by
    /// `meta`: the tables are dirty, the frame is a key frame that has not
    /// been preceded by PSI yet, or `PSI_INTERVAL` has elapsed.
    fn check_psi_state(&self, meta: &ActiveFrame) -> bool {
        if self.psi_dirty || meta.pending_key_psi {
            return true;
        }
        is_timestamp_delta_exceeded(self.last_psi_timestamp, meta.timestamp, PSI_INTERVAL)
    }

    /// Prepares a PAT/PMT emission ahead of the active frame of stream `idx`.
    ///
    /// Dirty tables are rebuilt, otherwise the packetizers are rewound. The
    /// frame's PSI request is cleared and the state moves to `EmitPat`.
    ///
    /// # Panics
    ///
    /// Panics if stream `idx` does not exist or has no active frame.
    fn prepare_psi_state(&mut self, idx: usize, meta: &ActiveFrame) {
        if !self.psi_dirty {
            self.pat_packetizer.reset();
            if let Some(service) = &mut self.service {
                service.pmt_packetizer.reset();
            }
        } else {
            self.rebuild();
            self.psi_dirty = false;
        }

        self.last_psi_timestamp = meta.timestamp;
        let active = self.streams[idx].active.as_mut().unwrap();
        active.pending_key_psi = false;
        self.state = MultiplexerState::EmitPat;
    }

    /// Tells whether a PCR packet has to be emitted before the frame described
    /// by `meta`: the frame is a key frame still waiting for its PCR, or
    /// `PCR_INTERVAL` has elapsed.
    fn check_pcr_state(&self, meta: &ActiveFrame) -> bool {
        if meta.pending_key_pcr {
            return true;
        }
        is_timestamp_delta_exceeded(self.last_pcr_timestamp, meta.timestamp, PCR_INTERVAL)
    }

    /// Prepares a PCR emission ahead of the active frame of stream `idx` and
    /// moves the state to `EmitPcr`.
    ///
    /// # Panics
    ///
    /// Panics if stream `idx` does not exist or has no active frame, or if
    /// `meta` has no timestamp.
    fn prepare_pcr_state(&mut self, idx: usize, meta: &ActiveFrame) {
        self.last_pcr_timestamp = meta.timestamp;
        let active = self.streams[idx].active.as_mut().unwrap();
        active.pending_key_pcr = false;
        self.state = MultiplexerState::EmitPcr(meta.timestamp.unwrap());
    }

    /// Loads the next queued frame into every stream that has no active frame.
    fn load_frames(&mut self) {
        for stream in &mut self.streams {
            if stream.active.is_none() {
                stream.load_next_frame();
            }
        }
    }

    /// Picks the stream whose active frame has the earliest timestamp.
    ///
    /// Returns the stream index and a copy of the frame metadata. Active
    /// frames without a timestamp are never selected; on equal timestamps the
    /// stream registered first wins.
    fn select_stream(&self) -> Option<(usize, ActiveFrame)> {
        let mut best = None;

        for (idx, stream) in self.streams.iter().enumerate() {
            let Some(active) = stream.active else {
                continue;
            };
            let Some(candidate_ts) = active.timestamp else {
                continue;
            };

            let is_earlier = match best {
                None => true,
                Some((_, _, best_ts)) => candidate_ts.is_before(best_ts),
            };
            if is_earlier {
                best = Some((idx, active, candidate_ts));
            }
        }

        best.map(|(idx, active, _)| (idx, active))
    }

    /// Rebuilds the PAT and PMT sections from the registered service and
    /// streams and hands them to their packetizers.
    ///
    /// Does nothing when no service is registered. Both tables are always
    /// built with version 0.
    fn rebuild(&mut self) {
        let Some(service) = &mut self.service else {
            return;
        };

        let pat_config = PatConfig {
            transport_stream_id: self.transport_stream_id,
            version: 0,
            programs: vec![PatProgram {
                program_number: service.program_number,
                pid: service.pmt_pid,
            }],
        };
        self.pat_packetizer
            .set_sections(PatBuilder::build(pat_config));

        let pmt_config = PmtConfig {
            program_number: service.program_number,
            pcr_pid: service.pcr_pid,
            version: 0,
            program_descriptors: service.program_descriptors.clone(),
            streams: self
                .streams
                .iter()
                .map(|context| context.pmt_stream.clone())
                .collect(),
        };
        service
            .pmt_packetizer
            .set_sections(PmtBuilder::build(pmt_config));
    }

    /// Returns the next PES stream id of the family starting at `base`.
    ///
    /// The result is `base` plus the number of registered streams whose id
    /// lies in `base .. base + limit`, so it never exceeds `base + limit`.
    fn next_stream_id(&self, base: u8, limit: u8) -> u8 {
        let family = base .. base + limit;
        let used = self
            .streams
            .iter()
            .filter(|context| family.contains(&context.stream_id))
            .count();
        base + used as u8
    }

    /// Writes PAT packets into `packets` and returns how many were written.
    ///
    /// Once the packetizer has nothing left, the state moves to `EmitPmt`.
    fn emit_pat(&mut self, packets: &mut [[u8; PACKET_SIZE]]) -> usize {
        for (count, packet) in packets.iter_mut().enumerate() {
            if !self.pat_packetizer.next(packet) {
                self.state = MultiplexerState::EmitPmt;
                return count;
            }
        }
        packets.len()
    }

    /// Writes PMT packets into `packets` and returns how many were written.
    ///
    /// Once the packetizer has nothing left, or when no service is
    /// registered, the state returns to `Idle`.
    fn emit_pmt(&mut self, packets: &mut [[u8; PACKET_SIZE]]) -> usize {
        let Some(service) = self.service.as_mut() else {
            self.state = MultiplexerState::Idle;
            return 0;
        };

        for (count, packet) in packets.iter_mut().enumerate() {
            if !service.pmt_packetizer.next(packet) {
                self.state = MultiplexerState::Idle;
                return count;
            }
        }
        packets.len()
    }

    /// Writes one PCR packet for `timestamp` into the first slot of `packets`
    /// and returns the number of packets written (0 or 1).
    ///
    /// The PCR value is `(timestamp - PCR_DELAY) * 300` and the packet is
    /// built by the packetizer of the service's PCR stream. Without a service
    /// the state returns to `Idle` and nothing is written; with an empty
    /// `packets` slice nothing is written and the state is kept.
    fn emit_pcr(&mut self, timestamp: Timestamp, packets: &mut [[u8; PACKET_SIZE]]) -> usize {
        let pcr_stream_index = match &self.service {
            Some(service) => service.pcr_stream_index,
            None => {
                self.state = MultiplexerState::Idle;
                return 0;
            }
        };

        // NOTE(review): the factor 300 presumably converts 90 kHz ticks to the
        // 27 MHz PCR clock — confirm against `build_pcr_packet`.
        let pcr = timestamp.wrapping_sub(PCR_DELAY).value() * 300;

        match packets.first_mut() {
            None => 0,
            Some(packet) => {
                self.streams[pcr_stream_index]
                    .packetizer
                    .build_pcr_packet(packet, pcr);
                self.state = MultiplexerState::Idle;
                1
            }
        }
    }

    /// Writes packets of the active frame of stream `idx` into `packets` and
    /// returns how many were written.
    ///
    /// When the frame is exhausted, or the stream has no active frame, the
    /// frame slot is cleared and the state returns to `Idle`.
    ///
    /// # Panics
    ///
    /// Panics if stream `idx` does not exist.
    fn emit_stream(&mut self, idx: usize, packets: &mut [[u8; PACKET_SIZE]]) -> usize {
        let stream = &mut self.streams[idx];
        if stream.active.is_none() {
            self.state = MultiplexerState::Idle;
            return 0;
        }

        for (count, packet) in packets.iter_mut().enumerate() {
            if !stream.packetizer.next(packet) {
                stream.active = None;
                self.state = MultiplexerState::Idle;
                return count;
            }
        }
        packets.len()
    }
}
/// Tells whether at least `interval` ticks lie between `last` and `current`.
///
/// Returns `false` when `current` is unknown and `true` when `current` is
/// known but `last` is not. Otherwise the wrapping difference
/// `current - last` is compared against `interval`.
fn is_timestamp_delta_exceeded(
    last: Option<Timestamp>,
    current: Option<Timestamp>,
    interval: u64,
) -> bool {
    let Some(now) = current else {
        return false;
    };

    match last {
        None => true,
        Some(previous) => now.wrapping_sub(previous).value() >= interval,
    }
}