mod pes;
pub mod scte35;
use async_trait::async_trait;
use oximedia_core::{CodecId, OxiError, OxiResult};
use oximedia_io::MediaSource;
use std::collections::HashMap;
use crate::{Muxer, MuxerConfig, Packet, StreamInfo};
use pes::PesPacketBuilder;
/// Size in bytes of every transport-stream packet this muxer emits.
const TS_PACKET_SIZE: usize = 188;
/// Value written to the first byte of every TS packet.
const SYNC_BYTE: u8 = 0x47;
/// PID on which the Program Association Table is written.
const PAT_PID: u16 = 0x0000;
/// PID on which the Program Map Table is written; also advertised inside the PAT.
const PMT_PID: u16 = 0x0100;
/// PID given to the first registered elementary stream; later streams get consecutive PIDs.
const FIRST_ES_PID: u16 = 0x0200;
/// Minimum advance of the packet timestamp before `write_packet` schedules another PCR.
// NOTE(review): 9000 equals 100 ms only if timestamps are 90 kHz ticks — confirm the unit of `Packet::pts`.
const PCR_INTERVAL: u64 = 9000;
/// Program number written into both the PAT and the PMT.
const PROGRAM_NUMBER: u16 = 1;
/// Transport stream id written into the PAT.
const TRANSPORT_STREAM_ID: u16 = 1;
/// Per-codec transport-stream parameters, produced by `StreamTypeInfo::from_codec`.
#[derive(Debug, Clone, Copy)]
struct StreamTypeInfo {
/// Value stored as the stream's `stream_type` and later written into its PMT entry.
stream_type: u8,
/// Whether a stream of this codec may be chosen as the program's PCR PID.
can_carry_pcr: bool,
}
impl StreamTypeInfo {
    /// Looks up the transport-stream parameters for `codec`.
    ///
    /// Returns `None` for any codec without an entry in the table below.
    /// The three video codecs are flagged as able to carry the PCR; the
    /// three audio codecs are not.
    const fn from_codec(codec: CodecId) -> Option<Self> {
        // (stream_type, can_carry_pcr), listed in ascending stream_type order.
        let (stream_type, can_carry_pcr) = match codec {
            CodecId::Pcm => (0x80, false),
            CodecId::Opus => (0x81, false),
            CodecId::Flac => (0x82, false),
            CodecId::Vp8 => (0x83, true),
            CodecId::Vp9 => (0x84, true),
            CodecId::Av1 => (0x85, true),
            _ => return None,
        };
        Some(Self {
            stream_type,
            can_carry_pcr,
        })
    }
}
/// State kept for one elementary stream registered with the muxer.
struct ElementaryStream {
/// Stream description supplied by the caller; the field is not read in the visible code.
#[allow(dead_code)]
info: StreamInfo,
/// PID written into this stream's TS packet headers and its PMT entry.
pid: u16,
/// `stream_type` byte written into this stream's PMT entry.
stream_type: u8,
/// Continuity counter for the next TS packet on `pid`; kept in the range 0..=15.
continuity_counter: u8,
/// PES builder created for this stream; `write_packet` calls `with_pts`/`with_dts` on it per packet.
pes_builder: PesPacketBuilder,
}
impl ElementaryStream {
    /// Builds the per-stream state for `info`, carried on `pid` with the given
    /// PMT `stream_type`. The continuity counter starts at zero.
    fn new(info: StreamInfo, pid: u16, stream_type: u8) -> Self {
        Self {
            // Evaluated first so the builder is created before `info` is moved in.
            pes_builder: PesPacketBuilder::new(info.codec, info.index),
            info,
            pid,
            stream_type,
            continuity_counter: 0,
        }
    }

    /// Returns the continuity counter to use for the next TS packet and
    /// advances the stored counter, wrapping from 15 back to 0.
    fn next_continuity_counter(&mut self) -> u8 {
        let following = (self.continuity_counter + 1) % 16;
        std::mem::replace(&mut self.continuity_counter, following)
    }
}
/// MPEG transport-stream muxer that writes fixed-size TS packets to a sink `S`.
pub struct MpegTsMuxer<S: MediaSource> {
/// Destination that receives every serialised TS packet.
sink: S,
/// Configuration passed to `new`; handed back unchanged by `config()`.
config: MuxerConfig,
/// Registered elementary streams, in registration order.
streams: Vec<ElementaryStream>,
/// Maps the caller's `StreamInfo::index` to a position in `streams`.
stream_by_index: HashMap<usize, usize>,
/// Continuity counter for the next PAT packet.
pat_continuity_counter: u8,
/// Continuity counter for the next PMT packet.
pmt_continuity_counter: u8,
/// PID advertised as the PCR PID in the PMT; `None` until a PCR-capable stream is added.
pcr_pid: Option<u16>,
/// Timestamp at which `write_packet` last scheduled a PCR.
last_pcr: u64,
/// Set once `write_header` has emitted the initial PAT and PMT.
header_written: bool,
/// Number of TS packets written to `sink` so far.
packets_written: u64,
}
impl<S: MediaSource> MpegTsMuxer<S> {
/// Creates a muxer that will write to `sink` using `config`.
///
/// The muxer starts with no streams, no PCR PID, zeroed counters and the
/// header not yet written.
#[must_use]
pub fn new(sink: S, config: MuxerConfig) -> Self {
    let streams = Vec::new();
    let stream_by_index = HashMap::new();
    Self {
        sink,
        config,
        streams,
        stream_by_index,
        pcr_pid: None,
        header_written: false,
        // Counters and clocks all begin at zero.
        pat_continuity_counter: 0,
        pmt_continuity_counter: 0,
        last_pcr: 0,
        packets_written: 0,
    }
}
async fn write_ts_packet(
&mut self,
pid: u16,
payload: &[u8],
payload_unit_start: bool,
continuity_counter: u8,
adaptation_field: Option<&[u8]>,
) -> OxiResult<()> {
let mut packet = [0u8; TS_PACKET_SIZE];
packet[0] = SYNC_BYTE;
packet[1] = if payload_unit_start { 0x40 } else { 0x00 } | ((pid >> 8) as u8 & 0x1F);
packet[2] = (pid & 0xFF) as u8;
let has_adaptation = adaptation_field.is_some();
let has_payload = !payload.is_empty();
let afc = match (has_adaptation, has_payload) {
(false, _) => 0x01, (true, false) => 0x02, (true, true) => 0x03, };
packet[3] = (afc << 4) | (continuity_counter & 0x0F);
let mut offset = 4;
if let Some(af) = adaptation_field {
#[allow(clippy::cast_possible_truncation)]
let af_len = af.len() as u8;
packet[offset] = af_len;
offset += 1;
let copy_len = std::cmp::min(af.len(), TS_PACKET_SIZE - offset);
packet[offset..offset + copy_len].copy_from_slice(&af[..copy_len]);
offset += copy_len;
}
if !payload.is_empty() {
let copy_len = std::cmp::min(payload.len(), TS_PACKET_SIZE - offset);
packet[offset..offset + copy_len].copy_from_slice(&payload[..copy_len]);
offset += copy_len;
}
for byte in &mut packet[offset..] {
*byte = 0xFF;
}
self.sink.write_all(&packet).await?;
self.packets_written += 1;
Ok(())
}
/// Writes the Program Association Table as a single TS packet on `PAT_PID`.
///
/// The table announces one program (`PROGRAM_NUMBER`) whose PMT lives on
/// `PMT_PID`. The section is followed by its CRC-32 and preceded by a
/// zero pointer byte; the PAT continuity counter is advanced afterwards.
///
/// # Errors
///
/// Propagates any error from `write_ts_packet`.
async fn write_pat(&mut self) -> OxiResult<()> {
    let [tsid_hi, tsid_lo] = TRANSPORT_STREAM_ID.to_be_bytes();
    let [program_hi, program_lo] = PROGRAM_NUMBER.to_be_bytes();
    let [pmt_hi, pmt_lo] = PMT_PID.to_be_bytes();

    let mut section = vec![
        0x00, // table_id
        0xB0, // section_syntax_indicator + reserved bits, length high nibble
        0x0D, // section_length = 13 (everything after this byte, CRC included)
        tsid_hi,
        tsid_lo,
        0xC1, // reserved bits, version 0, current_next_indicator = 1
        0x00, // section_number
        0x00, // last_section_number
        program_hi,
        program_lo,
        pmt_hi | 0xE0, // three reserved bits above the 13-bit PID
        pmt_lo,
    ];
    let crc = Self::compute_crc32(&section);
    section.extend_from_slice(&crc.to_be_bytes());

    // The payload starts with a pointer byte of 0: the section follows immediately.
    let mut payload = Vec::with_capacity(1 + section.len());
    payload.push(0x00);
    payload.extend_from_slice(&section);

    let cc = self.pat_continuity_counter;
    self.pat_continuity_counter = (cc + 1) & 0x0F;
    self.write_ts_packet(PAT_PID, &payload, true, cc, None)
        .await
}
async fn write_pmt(&mut self) -> OxiResult<()> {
let mut section = Vec::new();
section.push(0x02);
let section_length_pos = section.len();
section.push(0xB0);
section.push(0x00);
section.push((PROGRAM_NUMBER >> 8) as u8);
section.push((PROGRAM_NUMBER & 0xFF) as u8);
section.push(0xC1);
section.push(0x00);
section.push(0x00);
let pcr_pid = self.pcr_pid.unwrap_or(0x1FFF);
section.push((pcr_pid >> 8) as u8 | 0xE0);
section.push((pcr_pid & 0xFF) as u8);
section.push(0xF0);
section.push(0x00);
for stream in &self.streams {
section.push(stream.stream_type);
section.push((stream.pid >> 8) as u8 | 0xE0);
section.push((stream.pid & 0xFF) as u8);
section.push(0xF0);
section.push(0x00);
}
let section_length = section.len() - 3 + 4; #[allow(clippy::cast_possible_truncation)]
{
section[section_length_pos + 1] = ((section_length >> 8) as u8 & 0x0F) | 0xB0;
section[section_length_pos + 2] = (section_length & 0xFF) as u8;
}
let crc = Self::compute_crc32(§ion);
section.extend_from_slice(&crc.to_be_bytes());
let mut payload = vec![0x00]; payload.extend_from_slice(§ion);
let cc = self.pmt_continuity_counter;
self.pmt_continuity_counter = (self.pmt_continuity_counter + 1) & 0x0F;
self.write_ts_packet(PMT_PID, &payload, true, cc, None)
.await
}
/// Computes the CRC-32 used by PSI sections over `data`.
///
/// Bitwise, most-significant-bit first, polynomial `0x04C1_1DB7`, initial
/// register `0xFFFF_FFFF`, no reflection and no final XOR. Empty input
/// therefore yields `0xFFFF_FFFF`.
fn compute_crc32(data: &[u8]) -> u32 {
    const POLYNOMIAL: u32 = 0x04C1_1DB7;

    data.iter().fold(0xFFFF_FFFF_u32, |register, &byte| {
        // Feed the byte into the top of the register, then clock out 8 bits.
        (0..8).fold(register ^ (u32::from(byte) << 24), |reg, _| {
            let shifted = reg << 1;
            if reg >> 31 == 1 {
                shifted ^ POLYNOMIAL
            } else {
                shifted
            }
        })
    })
}
/// Packs a PCR value into the 6-byte field carried in an adaptation field.
///
/// `pcr` is split into `base = pcr / 300` and `extension = pcr % 300`. The
/// 48-bit result holds the low 33 bits of the base, six reserved bits set to
/// one, and the 9-bit extension, in big-endian order.
fn encode_pcr(pcr: u64) -> [u8; 6] {
    const BASE_MASK: u64 = 0x1_FFFF_FFFF; // 33 bits
    const RESERVED: u64 = 0x3F << 9; // six one-bits between base and extension

    let base = pcr / 300;
    let extension = pcr % 300; // always < 300, so it fits in 9 bits

    let field = ((base & BASE_MASK) << 15) | RESERVED | extension;
    // The field occupies the low 48 bits; drop the two leading zero bytes.
    let [_, _, b0, b1, b2, b3, b4, b5] = field.to_be_bytes();
    [b0, b1, b2, b3, b4, b5]
}
/// Splits one PES packet into TS packets on the stream's PID and writes them.
///
/// Only the first TS packet has `payload_unit_start` set. When `pcr` is
/// `Some` and the stream is the program's PCR PID, that first packet also
/// carries the PCR in its adaptation field.
///
/// When the remaining PES bytes do not fill a TS packet, the gap is closed
/// with adaptation-field stuffing so the payload ends exactly at the packet
/// boundary. Previously the gap was filled with `0xFF` *after* the payload,
/// which a demuxer reads as part of the PES data.
///
/// Nothing is written when `pes_data` is empty.
///
/// # Errors
///
/// Propagates any error from `write_ts_packet`.
///
/// # Panics
///
/// Panics if `stream_idx` is out of range for `self.streams`.
async fn write_pes_packet(
    &mut self,
    stream_idx: usize,
    pes_data: &[u8],
    pcr: Option<u64>,
) -> OxiResult<()> {
    // Bytes left in a TS packet after its fixed 4-byte header.
    const BODY_SIZE: usize = TS_PACKET_SIZE - 4;

    let pid = self.streams[stream_idx].pid;
    let on_pcr_pid = self.pcr_pid == Some(pid);

    let mut offset = 0;
    while offset < pes_data.len() {
        let first_packet = offset == 0;
        let remaining = pes_data.len() - offset;

        // Adaptation-field body, i.e. everything after its length byte.
        let mut af_body: Vec<u8> = Vec::new();
        if first_packet && on_pcr_pid {
            if let Some(pcr_val) = pcr {
                af_body.push(0x10); // flags: PCR present
                af_body.extend_from_slice(&Self::encode_pcr(pcr_val));
            }
        }
        let mut has_af = !af_body.is_empty();

        let af_size = if has_af { 1 + af_body.len() } else { 0 };
        let available = BODY_SIZE - af_size;
        let payload_size = std::cmp::min(remaining, available);

        // Close any gap with stuffing inside the adaptation field.
        let shortfall = available - payload_size;
        if shortfall > 0 {
            if has_af {
                af_body.resize(af_body.len() + shortfall, 0xFF);
            } else {
                // A new adaptation field spends one byte on its length. A
                // one-byte gap is therefore a zero-length field; larger gaps
                // add an empty flags byte followed by 0xFF stuffing.
                has_af = true;
                if shortfall > 1 {
                    af_body.push(0x00);
                    af_body.resize(shortfall - 1, 0xFF);
                }
            }
        }

        let adaptation_field = if has_af {
            Some(af_body.as_slice())
        } else {
            None
        };
        let payload = &pes_data[offset..offset + payload_size];
        let cc = self.streams[stream_idx].next_continuity_counter();
        self.write_ts_packet(pid, payload, first_packet, cc, adaptation_field)
            .await?;
        offset += payload_size;
    }
    Ok(())
}
}
#[async_trait]
impl<S: MediaSource> Muxer for MpegTsMuxer<S> {
/// Registers an elementary stream and returns the caller's stream index.
///
/// Streams receive consecutive PIDs starting at `FIRST_ES_PID`. The first
/// registered stream whose codec can carry a PCR becomes the PCR PID.
///
/// # Errors
///
/// * `OxiError::InvalidData` if the header has already been written.
/// * An "unsupported" error if the codec has no MPEG-TS mapping.
fn add_stream(&mut self, info: StreamInfo) -> OxiResult<usize> {
    if self.header_written {
        return Err(OxiError::InvalidData(
            "Cannot add streams after header is written".to_string(),
        ));
    }

    let type_info = StreamTypeInfo::from_codec(info.codec).ok_or_else(|| {
        OxiError::unsupported(format!("Codec {:?} not supported in MPEG-TS", info.codec))
    })?;

    // Position the new stream will occupy in `self.streams`.
    let slot = self.streams.len();
    #[allow(clippy::cast_possible_truncation)]
    let pid = FIRST_ES_PID + slot as u16;

    if type_info.can_carry_pcr {
        // Keeps an already chosen PCR PID; otherwise adopts this one.
        self.pcr_pid.get_or_insert(pid);
    }

    let caller_index = info.index;
    self.streams
        .push(ElementaryStream::new(info, pid, type_info.stream_type));
    self.stream_by_index.insert(caller_index, slot);
    Ok(caller_index)
}
/// Emits the initial PAT and PMT and marks the header as written.
///
/// # Errors
///
/// Returns `OxiError::InvalidData` if the header was already written or no
/// stream has been added; propagates any error from writing the tables.
async fn write_header(&mut self) -> OxiResult<()> {
    let refusal = if self.header_written {
        Some("Header already written")
    } else if self.streams.is_empty() {
        Some("No streams added to muxer")
    } else {
        None
    };
    if let Some(message) = refusal {
        return Err(OxiError::InvalidData(message.to_string()));
    }

    self.write_pat().await?;
    self.write_pmt().await?;
    self.header_written = true;
    Ok(())
}
/// Wraps `packet` in a PES packet and writes it as TS packets.
///
/// The PES header carries the packet's PTS, and its DTS (falling back to
/// the PTS when absent). A PCR is scheduled only for packets on the PCR PID,
/// once the timestamp has advanced by at least `PCR_INTERVAL` since the last
/// scheduled PCR.
///
/// Two defects are fixed here. Packets on other PIDs no longer advance the
/// PCR schedule: they never emit the PCR, so letting them consume it dropped
/// PCRs from the stream. Negative timestamps no longer take part in PCR
/// scheduling: cast to `u64` they became huge values that stalled the
/// schedule and could overflow the interval addition.
///
/// # Errors
///
/// Returns `OxiError::InvalidData` if the header has not been written or the
/// packet's stream index is unknown; propagates PES-building and write errors.
async fn write_packet(&mut self, packet: &Packet) -> OxiResult<()> {
    if !self.header_written {
        return Err(OxiError::InvalidData(
            "Must write header before packets".to_string(),
        ));
    }
    let stream_idx = *self
        .stream_by_index
        .get(&packet.stream_index)
        .ok_or_else(|| {
            OxiError::InvalidData(format!("Invalid stream index: {}", packet.stream_index))
        })?;

    let pts = packet.pts();
    let dts = packet.dts();
    let pes_builder = self.streams[stream_idx]
        .pes_builder
        .with_pts(pts)
        .with_dts(dts.unwrap_or(pts));
    let pes_data = pes_builder.build(&packet.data)?;

    // NOTE(review): the timestamp is used unscaled as the PCR, while
    // `encode_pcr` divides by 300 to obtain the base. If `pts` is in 90 kHz
    // ticks (as `PCR_INTERVAL` = 9000 suggests), the PCR would need
    // `pts * 300` — confirm the intended unit before changing.
    let on_pcr_pid = self.pcr_pid == Some(self.streams[stream_idx].pid);
    let pcr = if on_pcr_pid && pts >= 0 {
        // Sign was checked above, so the cast cannot wrap.
        #[allow(clippy::cast_sign_loss)]
        let current_pcr = pts as u64;
        if current_pcr >= self.last_pcr.saturating_add(PCR_INTERVAL) {
            self.last_pcr = current_pcr;
            Some(current_pcr)
        } else {
            None
        }
    } else {
        None
    };

    self.write_pes_packet(stream_idx, &pes_data, pcr).await?;
    Ok(())
}
/// Finishes the stream by writing one more PAT and PMT.
///
/// # Errors
///
/// Returns `OxiError::InvalidData` if the header was never written;
/// propagates any error from writing the tables.
async fn write_trailer(&mut self) -> OxiResult<()> {
    if self.header_written {
        self.write_pat().await?;
        self.write_pmt().await?;
        Ok(())
    } else {
        Err(OxiError::InvalidData("No header written".to_string()))
    }
}
/// Returns the stream descriptions exposed through the `Muxer` trait.
///
/// NOTE(review): this always returns an empty slice, even after `add_stream`
/// has succeeded. The registered `StreamInfo` values are stored inside the
/// `ElementaryStream` entries, so returning them as a slice would need a
/// separate `Vec<StreamInfo>` on the muxer — confirm whether callers rely on this.
fn streams(&self) -> &[StreamInfo] {
&[]
}
/// Returns the configuration this muxer was constructed with.
fn config(&self) -> &MuxerConfig {
&self.config
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use oximedia_core::Rational;
    use oximedia_io::MemorySource;

    /// The PCR field must hold base (33 bits), six one-bits, extension (9 bits).
    #[test]
    fn test_encode_pcr() {
        // 90_000 = 300 * 300 + 0, so base = 300 and extension = 0.
        let encoded = MpegTsMuxer::<MemorySource>::encode_pcr(90_000);
        assert_eq!(encoded, [0x00, 0x00, 0x00, 0x96, 0x7E, 0x00]);

        // Same base with the largest extension (299 = 0x12B).
        let encoded = MpegTsMuxer::<MemorySource>::encode_pcr(90_299);
        assert_eq!(encoded, [0x00, 0x00, 0x00, 0x96, 0x7F, 0x2B]);

        // Zero still carries the reserved one-bits.
        let encoded = MpegTsMuxer::<MemorySource>::encode_pcr(0);
        assert_eq!(encoded, [0x00, 0x00, 0x00, 0x00, 0x7E, 0x00]);
    }

    /// The CRC must match the published CRC-32/MPEG-2 check value.
    #[test]
    fn test_compute_crc32() {
        assert_eq!(
            MpegTsMuxer::<MemorySource>::compute_crc32(b"123456789"),
            0x0376_E6E7
        );
        assert_eq!(MpegTsMuxer::<MemorySource>::compute_crc32(&[]), 0xFFFF_FFFF);

        // A section followed by its own CRC checks to zero.
        let mut data = vec![0x00, 0xB0, 0x0D, 0x00, 0x01, 0xC1, 0x00, 0x00];
        let crc = MpegTsMuxer::<MemorySource>::compute_crc32(&data);
        assert!(crc != 0);
        data.extend_from_slice(&crc.to_be_bytes());
        assert_eq!(MpegTsMuxer::<MemorySource>::compute_crc32(&data), 0);
    }

    #[test]
    fn test_stream_type_from_codec() {
        let av1 = StreamTypeInfo::from_codec(CodecId::Av1);
        assert_eq!(av1.map(|info| info.stream_type), Some(0x85));
        assert_eq!(av1.map(|info| info.can_carry_pcr), Some(true));

        let opus = StreamTypeInfo::from_codec(CodecId::Opus);
        assert_eq!(opus.map(|info| info.stream_type), Some(0x81));
        assert_eq!(opus.map(|info| info.can_carry_pcr), Some(false));

        assert!(StreamTypeInfo::from_codec(CodecId::Mp3).is_none());
    }

    #[tokio::test]
    async fn test_add_stream() {
        let source = MemorySource::new(bytes::Bytes::new());
        let config = MuxerConfig::new();
        let mut muxer = MpegTsMuxer::new(source, config);
        let stream_info = StreamInfo::new(0, CodecId::Av1, Rational::new(1, 90000));
        let index = muxer
            .add_stream(stream_info)
            .expect("operation should succeed");
        assert_eq!(index, 0);
        assert_eq!(muxer.streams.len(), 1);
        // The first stream takes the first elementary PID; AV1 can carry the PCR.
        assert_eq!(muxer.streams[0].pid, FIRST_ES_PID);
        assert_eq!(muxer.pcr_pid, Some(FIRST_ES_PID));
    }
}