use crate::packet;
use crate::packet::TransportScramblingControl;
use crate::psi;
use crate::psi::pat;
use crate::psi::pmt::PmtSection;
use crate::psi::pmt::StreamInfo;
use crate::StreamType;
use log::warn;
use std::marker;
/// A handler for the transport stream packets carried on one PID.
///
/// `Demultiplex::push()` looks up the filter registered for each packet's PID and passes the
/// packet to `consume()` together with the demultiplexer's context. Packets with
/// `transport_error_indicator` set, or that are scrambled, are dropped by `push()` and never
/// reach a filter.
pub trait PacketFilter {
    /// The context type shared with the demultiplexer.
    type Ctx: DemuxContext;
    /// Handles one packet.
    ///
    /// Implementations may queue filter-table changes through `ctx.filter_changeset()`; the
    /// demultiplexer applies them after this call returns and before the next packet.
    fn consume(&mut self, ctx: &mut Self::Ctx, pk: &packet::Packet<'_>);
}
/// A `PacketFilter` that discards every packet it is given.
///
/// It can be returned from `DemuxContext::construct()` for PIDs whose packets are of no
/// interest. It stores no data; `PhantomData` only ties it to a particular context type.
pub struct NullPacketFilter<Ctx: DemuxContext> {
    phantom: marker::PhantomData<Ctx>,
}
impl<Ctx: DemuxContext> Default for NullPacketFilter<Ctx> {
    /// Creates a stateless filter that ignores all packets.
    fn default() -> NullPacketFilter<Ctx> {
        Self {
            phantom: marker::PhantomData,
        }
    }
}
impl<Ctx: DemuxContext> PacketFilter for NullPacketFilter<Ctx> {
    type Ctx = Ctx;

    /// Does nothing: the packet is dropped and no filter changes are queued.
    fn consume(&mut self, _context: &mut Self::Ctx, _packet: &packet::Packet<'_>) {}
}
/// Defines a `DemuxContext` implementation named `$name` whose filter type is `$filter`.
///
/// The generated struct holds only the pending `FilterChangeset` and gets a `new()`
/// constructor. The caller must also provide an inherent method
/// `fn do_construct(&mut self, req: FilterRequest<'_, '_>) -> $filter`. The generated
/// `DemuxContext::construct()` simply delegates to it.
#[macro_export]
macro_rules! demux_context {
    ($name:ident, $filter:ty) => {
        pub struct $name {
            changeset: $crate::demultiplex::FilterChangeset<$filter>,
        }
        impl $name {
            pub fn new() -> Self {
                $name {
                    changeset: $crate::demultiplex::FilterChangeset::default(),
                }
            }
        }
        impl $crate::demultiplex::DemuxContext for $name {
            type F = $filter;
            fn filter_changeset(&mut self) -> &mut $crate::demultiplex::FilterChangeset<Self::F> {
                &mut self.changeset
            }
            fn construct(&mut self, req: $crate::demultiplex::FilterRequest<'_, '_>) -> Self::F {
                // Factory logic lives in the user-supplied inherent `do_construct`.
                self.do_construct(req)
            }
        }
    };
}
/// Generates an enum `$name` with one tuple variant `$case_name($t)` per listed filter type.
///
/// The macro also emits a `PacketFilter` implementation (with `Ctx = $ctx`) for the enum.
/// That implementation forwards each packet to whichever filter the current variant wraps.
/// Each `$t` must provide `consume(&mut self, &mut $ctx, &Packet<'_>)`, normally by
/// implementing `PacketFilter<Ctx = $ctx>`.
///
/// A trailing comma after the last case is optional.
#[macro_export]
macro_rules! packet_filter_switch {
    (
        $name:ident<$ctx:ty> {
            $( $case_name:ident : $t:ty ),* $(,)?
        }
    ) => {
        pub enum $name {
            $( $case_name($t), )*
        }
        impl $crate::demultiplex::PacketFilter for $name {
            type Ctx = $ctx;
            #[inline(always)]
            fn consume(&mut self, ctx: &mut $ctx, pk: &$crate::packet::Packet<'_>) {
                match self {
                    $( &mut $name::$case_name(ref mut f) => f.consume(ctx, pk), )*
                }
            }
        }
    }
}
/// A PID-indexed table of packet filters.
///
/// Slot `i` holds the filter for PID `i`, or `None` if there is none. The vector only grows
/// as far as the highest PID inserted so far.
struct Filters<F: PacketFilter> {
    filters_by_pid: Vec<Option<F>>,
}
impl<F: PacketFilter> Default for Filters<F> {
    /// Creates an empty table; slots are allocated lazily by `insert()`.
    fn default() -> Filters<F> {
        Self {
            filters_by_pid: Vec::new(),
        }
    }
}
impl<F: PacketFilter> Filters<F> {
    /// Returns `true` if a filter is registered for `pid`.
    pub fn contains(&self, pid: packet::Pid) -> bool {
        self.filters_by_pid
            .get(usize::from(pid))
            .map_or(false, Option::is_some)
    }

    /// Returns the filter registered for `pid`, if any.
    pub fn get(&mut self, pid: packet::Pid) -> Option<&mut F> {
        self.filters_by_pid
            .get_mut(usize::from(pid))
            .and_then(Option::as_mut)
    }

    /// Registers `filter` for `pid`, replacing (and dropping) any filter already there.
    ///
    /// The table is grown with empty slots as needed so that `pid` becomes a valid index.
    pub fn insert(&mut self, pid: packet::Pid, filter: F) {
        let index = usize::from(pid);
        if index >= self.filters_by_pid.len() {
            // `F` is not required to be `Clone`, so fill new slots via a closure.
            self.filters_by_pid.resize_with(index + 1, || None);
        }
        self.filters_by_pid[index] = Some(filter);
    }

    /// Drops the filter registered for `pid`, if any; unknown PIDs are ignored.
    pub fn remove(&mut self, pid: packet::Pid) {
        if let Some(slot) = self.filters_by_pid.get_mut(usize::from(pid)) {
            *slot = None;
        }
    }
}
/// A single pending modification to the demultiplexer's filter table.
pub enum FilterChange<F: PacketFilter> {
    /// Install the filter for the PID, replacing any existing one.
    Insert(packet::Pid, F),
    /// Remove the filter for the PID, if there is one.
    Remove(packet::Pid),
}
impl<F: PacketFilter> FilterChange<F> {
    /// Consumes this change and performs it on `filters`.
    fn apply(self, filters: &mut Filters<F>) {
        match self {
            FilterChange::Remove(pid) => filters.remove(pid),
            FilterChange::Insert(pid, filter) => filters.insert(pid, filter),
        }
    }
}
impl<F: PacketFilter> std::fmt::Debug for FilterChange<F> {
    /// Prints the variant and PID only; the filter is elided, so `F` need not be `Debug`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
        match self {
            FilterChange::Remove(pid) => write!(f, "FilterChange::Remove {{ {:?}, ... }}", pid),
            FilterChange::Insert(pid, _) => {
                write!(f, "FilterChange::Insert {{ {:?}, ... }}", pid)
            }
        }
    }
}
/// Filter-table changes queued while a packet is being processed.
///
/// `Demultiplex::push()` holds a mutable borrow of the current filter while calling it. The
/// filter therefore cannot edit the table directly. Instead it queues changes here, via
/// `DemuxContext::filter_changeset()`. `push()` applies them, in queue order, once the
/// filter's `consume()` returns.
#[derive(Debug)]
pub struct FilterChangeset<F: PacketFilter> {
    updates: Vec<FilterChange<F>>,
}
impl<F: PacketFilter> Default for FilterChangeset<F> {
    /// Creates a changeset with nothing pending.
    fn default() -> FilterChangeset<F> {
        Self { updates: vec![] }
    }
}
impl<F: PacketFilter> FilterChangeset<F> {
    /// Queues `filter` to be installed for `pid`.
    pub fn insert(&mut self, pid: packet::Pid, filter: F) {
        let change = FilterChange::Insert(pid, filter);
        self.updates.push(change);
    }

    /// Queues removal of whatever filter is registered for `pid`.
    pub fn remove(&mut self, pid: packet::Pid) {
        let change = FilterChange::Remove(pid);
        self.updates.push(change);
    }

    /// Applies every queued change to `filters` in queue order, leaving this changeset empty.
    fn apply(&mut self, filters: &mut Filters<F>) {
        self.updates
            .drain(..)
            .for_each(|change| change.apply(filters));
    }

    /// Returns `true` when no changes are pending.
    pub fn is_empty(&self) -> bool {
        self.updates.is_empty()
    }
}
impl<F: PacketFilter> std::iter::IntoIterator for FilterChangeset<F> {
type Item = FilterChange<F>;
type IntoIter = std::vec::IntoIter<FilterChange<F>>;
fn into_iter(self) -> std::vec::IntoIter<FilterChange<F>> {
self.updates.into_iter()
}
}
/// Describes why a new filter is needed; passed to `DemuxContext::construct()`.
#[derive(Debug)]
pub enum FilterRequest<'a, 'buf> {
    /// A filter for a bare PID. `Demultiplex` issues this for the PAT PID when it is created,
    /// and for any packet whose PID has no filter registered yet.
    ByPid(packet::Pid),
    /// A filter for one elementary stream listed in a PMT.
    ByStream {
        /// The PID of the PMT that listed this stream.
        program_pid: packet::Pid,
        /// The stream's `stream_type` as given in the PMT.
        stream_type: StreamType,
        /// The PMT section the stream was found in.
        pmt: &'a PmtSection<'buf>,
        /// The PMT entry describing this stream.
        stream_info: &'a StreamInfo<'buf>,
    },
    /// A filter for a program listed in the PAT.
    Pmt {
        /// The PID the PAT gives for this program.
        pid: packet::Pid,
        /// The program number from the PAT entry.
        program_number: u16,
    },
    /// A filter for the network entry listed in the PAT.
    Nit {
        /// The PID the PAT gives for the network entry.
        pid: packet::Pid,
    },
}
/// Turns complete PMT sections for one program into filter-table changes.
struct PmtProcessor<Ctx: DemuxContext> {
    /// The PID this PMT is carried on (used in logs and `FilterRequest::ByStream`).
    pid: packet::Pid,
    /// The program number this PMT was registered for (used in logs).
    program_number: u16,
    /// Elementary-stream PIDs for which this processor last queued filters.
    filters_registered: fixedbitset::FixedBitSet,
    phantom: marker::PhantomData<Ctx>,
}
impl<Ctx: DemuxContext> PmtProcessor<Ctx> {
    /// Creates a processor for the PMT on `pid` describing `program_number`.
    ///
    /// No elementary-stream filters are registered until the first section is processed.
    pub fn new(pid: packet::Pid, program_number: u16) -> PmtProcessor<Ctx> {
        PmtProcessor {
            pid,
            program_number,
            filters_registered: fixedbitset::FixedBitSet::with_capacity(packet::Pid::PID_COUNT),
            phantom: marker::PhantomData,
        }
    }

    /// Handles one complete PMT section.
    ///
    /// Sections whose `table_id` is not `0x02` are logged and ignored. Otherwise:
    /// - for every elementary stream listed, a filter is requested from `ctx` with
    ///   `FilterRequest::ByStream` and queued for insertion on the stream's PID;
    /// - filters this processor registered previously for PIDs no longer listed are queued
    ///   for removal.
    fn new_table(
        &mut self,
        ctx: &mut Ctx,
        header: &psi::SectionCommonHeader,
        _table_syntax_header: &psi::TableSyntaxHeader<'_>,
        sect: &PmtSection<'_>,
    ) {
        if 0x02 != header.table_id {
            warn!(
                "[PMT {:?} program:{}] Expected PMT to have table id 0x2, but got {:#x}",
                self.pid, self.program_number, header.table_id
            );
            return;
        }
        let mut pids_seen = fixedbitset::FixedBitSet::with_capacity(packet::Pid::PID_COUNT);
        for stream_info in sect.streams() {
            let elementary_pid = stream_info.elementary_pid();
            let pes_packet_consumer = ctx.construct(FilterRequest::ByStream {
                program_pid: self.pid,
                stream_type: stream_info.stream_type(),
                pmt: sect,
                stream_info: &stream_info,
            });
            ctx.filter_changeset()
                .insert(elementary_pid, pes_packet_consumer);
            pids_seen.insert(usize::from(elementary_pid));
        }
        // `remove_outdated` replaces `filters_registered` with `pids_seen`, so there is no
        // need to update `filters_registered` inside the loop above.
        self.remove_outdated(ctx, pids_seen);
    }

    /// Queues removal of previously-registered PIDs that are absent from `pids_seen`, then
    /// records `pids_seen` as the current registration set.
    fn remove_outdated(&mut self, ctx: &mut Ctx, pids_seen: fixedbitset::FixedBitSet) {
        for pid in self.filters_registered.difference(&pids_seen) {
            // Bit indices are below `Pid::PID_COUNT` (the bitset capacity); this assumes that
            // bound fits in `u16`, as 13-bit TS PIDs do.
            ctx.filter_changeset().remove(packet::Pid::new(pid as u16));
        }
        self.filters_registered = pids_seen;
    }
}
impl<Ctx: DemuxContext> psi::WholeSectionSyntaxPayloadParser for PmtProcessor<Ctx> {
    type Context = Ctx;

    /// Extracts the PMT body from a whole section and passes it to `new_table()`.
    ///
    /// `data` is the complete section: common header, table-syntax header, PMT body, and a
    /// trailing 4-byte CRC_32, which is excluded from the body. Sections too short to hold the
    /// headers and CRC are logged and ignored rather than causing a panic. Parse errors from
    /// `PmtSection::from_bytes` are also logged.
    fn section<'a>(
        &mut self,
        ctx: &mut Self::Context,
        header: &psi::SectionCommonHeader,
        table_syntax_header: &psi::TableSyntaxHeader<'_>,
        data: &'a [u8],
    ) {
        const CRC_SIZE: usize = 4;
        let start = psi::SectionCommonHeader::SIZE + psi::TableSyntaxHeader::SIZE;
        if data.len() < start + CRC_SIZE {
            warn!(
                "[PMT {:?} program:{}] section too short: {} bytes",
                self.pid,
                self.program_number,
                data.len()
            );
            return;
        }
        let end = data.len() - CRC_SIZE;
        match PmtSection::from_bytes(&data[start..end]) {
            Ok(sect) => self.new_table(ctx, header, table_syntax_header, &sect),
            Err(e) => warn!(
                "[PMT {:?} program:{}] problem reading data: {:?}",
                self.pid, self.program_number, e
            ),
        }
    }
}
/// Errors reported by demultiplexing code.
///
/// NOTE(review): not constructed anywhere in this file; presumably used by other modules.
#[derive(Debug)]
pub enum DemuxError {
    /// Presumably: the named `field` needed `expected` bytes, but only `actual` were available.
    NotEnoughData {
        field: &'static str,
        expected: usize,
        actual: usize,
    },
}
/// A `PacketFilter` for a PID carrying a PMT.
///
/// Packets pass through nested `psi` layers, outermost first:
/// `SectionPacketConsumer`, `SectionSyntaxSectionProcessor`,
/// `DedupSectionSyntaxPayloadParser`, `BufferSectionSyntaxParser`, and
/// `CrcCheckWholeSectionSyntaxPayloadParser`. The innermost layer wraps a `PmtProcessor`,
/// which queues filters for the program's elementary streams.
pub struct PmtPacketFilter<Ctx: DemuxContext + 'static> {
    pmt_section_packet_consumer: psi::SectionPacketConsumer<
        psi::SectionSyntaxSectionProcessor<
            psi::DedupSectionSyntaxPayloadParser<
                psi::BufferSectionSyntaxParser<
                    psi::CrcCheckWholeSectionSyntaxPayloadParser<PmtProcessor<Ctx>>,
                >,
            >,
        >,
    >,
}
impl<Ctx: DemuxContext> PmtPacketFilter<Ctx> {
    /// Builds the section-processing pipeline for the PMT of `program_number` on `pid`.
    ///
    /// The pipeline is assembled innermost-first, starting from a fresh `PmtProcessor`.
    pub fn new(pid: packet::Pid, program_number: u16) -> PmtPacketFilter<Ctx> {
        let crc_checked =
            psi::CrcCheckWholeSectionSyntaxPayloadParser::new(PmtProcessor::new(pid, program_number));
        let buffered = psi::BufferSectionSyntaxParser::new(crc_checked);
        let deduped = psi::DedupSectionSyntaxPayloadParser::new(buffered);
        let sections = psi::SectionSyntaxSectionProcessor::new(deduped);
        PmtPacketFilter {
            pmt_section_packet_consumer: psi::SectionPacketConsumer::new(sections),
        }
    }
}
impl<Ctx: DemuxContext> PacketFilter for PmtPacketFilter<Ctx> {
type Ctx = Ctx;
fn consume(&mut self, ctx: &mut Self::Ctx, pk: &packet::Packet<'_>) {
self.pmt_section_packet_consumer.consume(ctx, pk);
}
}
/// Turns complete PAT sections into filter-table changes: one PMT or NIT filter per entry.
struct PatProcessor<Ctx: DemuxContext> {
    /// PIDs for which this processor last queued PMT/NIT filters.
    filters_registered: fixedbitset::FixedBitSet,
    phantom: marker::PhantomData<Ctx>,
}
impl<Ctx: DemuxContext> Default for PatProcessor<Ctx> {
    /// Creates a processor that has not registered any filters yet.
    fn default() -> PatProcessor<Ctx> {
        let no_pids = fixedbitset::FixedBitSet::with_capacity(packet::Pid::PID_COUNT);
        Self {
            filters_registered: no_pids,
            phantom: marker::PhantomData,
        }
    }
}
impl<Ctx: DemuxContext> PatProcessor<Ctx> {
    /// Handles one complete PAT section.
    ///
    /// Sections whose `table_id` is not `0x00` are logged and ignored. Otherwise each entry
    /// is handled as follows:
    /// - a program entry becomes a `FilterRequest::Pmt`;
    /// - the network entry becomes a `FilterRequest::Nit`;
    /// - the filter `ctx` constructs for the request is queued for insertion on the entry's PID.
    ///
    /// PIDs registered by an earlier PAT but no longer listed are queued for removal.
    fn new_table(
        &mut self,
        ctx: &mut Ctx,
        header: &psi::SectionCommonHeader,
        _table_syntax_header: &psi::TableSyntaxHeader<'_>,
        sect: &pat::PatSection<'_>,
    ) {
        if header.table_id != 0x00 {
            warn!(
                "Expected PAT to have table id 0x0, but got {:#x}",
                header.table_id
            );
            return;
        }
        let mut pids_seen = fixedbitset::FixedBitSet::with_capacity(packet::Pid::PID_COUNT);
        for desc in sect.programs() {
            let request = match desc {
                pat::ProgramDescriptor::Program {
                    program_number,
                    pid,
                } => FilterRequest::Pmt {
                    pid,
                    program_number,
                },
                pat::ProgramDescriptor::Network { pid } => FilterRequest::Nit { pid },
            };
            let filter = ctx.construct(request);
            let target_pid = desc.pid();
            ctx.filter_changeset().insert(target_pid, filter);
            pids_seen.insert(usize::from(target_pid));
        }
        // `filters_registered` is replaced by `pids_seen` inside `remove_outdated`.
        self.remove_outdated(ctx, pids_seen);
    }

    /// Queues removal of registered PIDs missing from `pids_seen`, then adopts `pids_seen`.
    fn remove_outdated(&mut self, ctx: &mut Ctx, pids_seen: fixedbitset::FixedBitSet) {
        for stale in self.filters_registered.difference(&pids_seen) {
            ctx.filter_changeset().remove(packet::Pid::new(stale as u16));
        }
        self.filters_registered = pids_seen;
    }
}
impl<Ctx: DemuxContext> psi::WholeSectionSyntaxPayloadParser for PatProcessor<Ctx> {
type Context = Ctx;
fn section<'a>(
&mut self,
ctx: &mut Self::Context,
header: &psi::SectionCommonHeader,
table_syntax_header: &psi::TableSyntaxHeader<'_>,
data: &'a [u8],
) {
let start = psi::SectionCommonHeader::SIZE + psi::TableSyntaxHeader::SIZE;
let end = data.len() - 4;
self.new_table(
ctx,
header,
table_syntax_header,
&pat::PatSection::new(&data[start..end]),
);
}
}
/// Application-supplied state shared by the demultiplexer and its filters.
///
/// It holds the queue of pending filter-table changes and acts as the factory for new
/// filters.
pub trait DemuxContext: Sized {
    /// The single filter type used for every PID. This is, for example, an enum generated by
    /// `packet_filter_switch!`, as in this module's tests.
    type F: PacketFilter<Ctx = Self>;
    /// Returns the queue of pending filter-table changes.
    fn filter_changeset(&mut self) -> &mut FilterChangeset<Self::F>;
    /// Creates a filter in response to `req`.
    fn construct(&mut self, req: FilterRequest<'_, '_>) -> Self::F;
}
/// A `PacketFilter` for the PID carrying the PAT.
///
/// Packets pass through nested `psi` layers, outermost first:
/// `SectionPacketConsumer`, `SectionSyntaxSectionProcessor`,
/// `DedupSectionSyntaxPayloadParser`, `BufferSectionSyntaxParser`, and
/// `CrcCheckWholeSectionSyntaxPayloadParser`. The innermost layer wraps a `PatProcessor`,
/// which queues PMT/NIT filters.
pub struct PatPacketFilter<Ctx: DemuxContext> {
    pat_section_packet_consumer: psi::SectionPacketConsumer<
        psi::SectionSyntaxSectionProcessor<
            psi::DedupSectionSyntaxPayloadParser<
                psi::BufferSectionSyntaxParser<
                    psi::CrcCheckWholeSectionSyntaxPayloadParser<PatProcessor<Ctx>>,
                >,
            >,
        >,
    >,
}
impl<Ctx: DemuxContext> Default for PatPacketFilter<Ctx> {
    /// Builds the PAT section pipeline innermost-first, around a fresh `PatProcessor`.
    fn default() -> PatPacketFilter<Ctx> {
        let crc_checked =
            psi::CrcCheckWholeSectionSyntaxPayloadParser::new(PatProcessor::default());
        let buffered = psi::BufferSectionSyntaxParser::new(crc_checked);
        let deduped = psi::DedupSectionSyntaxPayloadParser::new(buffered);
        let sections = psi::SectionSyntaxSectionProcessor::new(deduped);
        PatPacketFilter {
            pat_section_packet_consumer: psi::SectionPacketConsumer::new(sections),
        }
    }
}
impl<Ctx: DemuxContext> PacketFilter for PatPacketFilter<Ctx> {
type Ctx = Ctx;
fn consume(&mut self, ctx: &mut Self::Ctx, pk: &packet::Packet<'_>) {
self.pat_section_packet_consumer.consume(ctx, pk);
}
}
/// Routes transport stream packets to per-PID filters.
///
/// Filters are obtained on demand from the `DemuxContext` passed to `new()` and `push()`.
pub struct Demultiplex<Ctx: DemuxContext> {
    processor_by_pid: Filters<Ctx::F>,
}
impl<Ctx: DemuxContext> Demultiplex<Ctx> {
    /// Creates a demultiplexer whose only initial filter is for the PAT PID.
    ///
    /// That filter is obtained from `ctx` as `FilterRequest::ByPid(psi::pat::PAT_PID)`.
    pub fn new(ctx: &mut Ctx) -> Demultiplex<Ctx> {
        let mut result = Demultiplex {
            processor_by_pid: Filters::default(),
        };
        result.processor_by_pid.insert(
            psi::pat::PAT_PID,
            ctx.construct(FilterRequest::ByPid(psi::pat::PAT_PID)),
        );
        result
    }

    /// Feeds a buffer of transport stream data through the registered filters.
    ///
    /// `buf` is split into `packet::Packet::SIZE`-byte chunks. A trailing partial chunk is
    /// ignored; nothing is carried over to the next call. Each packet is handled as follows:
    /// - it goes to the filter for its PID;
    /// - if no filter is registered for that PID yet, one is first requested with
    ///   `FilterRequest::ByPid`;
    /// - packets with `transport_error_indicator` set, or that are scrambled, are logged and
    ///   dropped instead.
    ///
    /// Changes a filter queues in `ctx.filter_changeset()` are applied before the next packet
    /// is processed.
    ///
    /// NOTE(review): processing stops at the first chunk for which `Packet::try_new` returns
    /// `None`. The rest of `buf` is then silently discarded.
    pub fn push(&mut self, ctx: &mut Ctx, buf: &[u8]) {
        let mut itr = buf
            .chunks_exact(packet::Packet::SIZE)
            .map(packet::Packet::try_new);
        let mut pk = if let Some(Some(p)) = itr.next() {
            p
        } else {
            return;
        };
        // Each 'outer iteration looks up (or creates) the filter for `pk`'s PID once. The
        // 'inner loop then reuses it for the run of consecutive packets on that same PID.
        'outer: loop {
            let this_pid = pk.pid();
            if !self.processor_by_pid.contains(this_pid) {
                self.add_pid_filter(ctx, this_pid);
            };
            // Cannot fail: a filter for `this_pid` was ensured just above.
            let this_proc = self.processor_by_pid.get(this_pid).unwrap();
            'inner: loop {
                if pk.transport_error_indicator() {
                    warn!("{:?} transport_error_indicator", pk.pid());
                } else if pk.transport_scrambling_control()
                    != TransportScramblingControl::NotScrambled
                {
                    warn!(
                        "{:?} dropping scrambled packet {:?}",
                        pk.pid(),
                        pk.transport_scrambling_control()
                    );
                } else {
                    this_proc.consume(ctx, &pk);
                    // The filter table cannot be modified while `this_proc` borrows it, so
                    // leave the inner loop to apply the queued changes.
                    if !ctx.filter_changeset().is_empty() {
                        break 'inner;
                    }
                }
                pk = if let Some(Some(p)) = itr.next() {
                    p
                } else {
                    break 'outer;
                };
                // A packet on a different PID needs a fresh filter lookup.
                if pk.pid() != this_pid {
                    continue 'outer;
                }
            }
            if !ctx.filter_changeset().is_empty() {
                ctx.filter_changeset().apply(&mut self.processor_by_pid);
            }
            debug_assert!(ctx.filter_changeset().is_empty());
            pk = if let Some(Some(p)) = itr.next() {
                p
            } else {
                break 'outer;
            };
        }
    }

    /// Asks `ctx` for a filter for `this_pid` (via `FilterRequest::ByPid`) and registers it.
    fn add_pid_filter(&mut self, ctx: &mut Ctx, this_pid: packet::Pid) {
        let filter = ctx.construct(FilterRequest::ByPid(this_pid));
        self.processor_by_pid.insert(this_pid, filter);
    }
}
#[cfg(test)]
pub(crate) mod test {
    use bitstream_io::{BitWriter, BE};
    use hex_literal::*;
    use std::io;

    use crate::demultiplex;
    use crate::packet;
    use crate::psi;
    use crate::psi::WholeSectionSyntaxPayloadParser;
    use bitstream_io::BigEndian;

    packet_filter_switch! {
        NullFilterSwitch<NullDemuxContext> {
            Pat: demultiplex::PatPacketFilter<NullDemuxContext>,
            Pmt: demultiplex::PmtPacketFilter<NullDemuxContext>,
            Nul: demultiplex::NullPacketFilter<NullDemuxContext>,
        }
    }

    demux_context!(NullDemuxContext, NullFilterSwitch);

    impl NullDemuxContext {
        /// Test filter factory: real PAT and PMT handling; every other request gets a filter
        /// that discards its packets.
        fn do_construct(&mut self, req: demultiplex::FilterRequest<'_, '_>) -> NullFilterSwitch {
            match req {
                demultiplex::FilterRequest::ByPid(psi::pat::PAT_PID) => {
                    NullFilterSwitch::Pat(demultiplex::PatPacketFilter::default())
                }
                demultiplex::FilterRequest::ByPid(_) => {
                    NullFilterSwitch::Nul(demultiplex::NullPacketFilter::default())
                }
                demultiplex::FilterRequest::ByStream { .. } => {
                    NullFilterSwitch::Nul(demultiplex::NullPacketFilter::default())
                }
                demultiplex::FilterRequest::Pmt {
                    pid,
                    program_number,
                } => NullFilterSwitch::Pmt(demultiplex::PmtPacketFilter::new(pid, program_number)),
                demultiplex::FilterRequest::Nit { .. } => {
                    NullFilterSwitch::Nul(demultiplex::NullPacketFilter::default())
                }
            }
        }
    }

    #[test]
    fn demux_empty() {
        let mut ctx = NullDemuxContext::new();
        let mut deplex = demultiplex::Demultiplex::new(&mut ctx);
        deplex.push(&mut ctx, &[0x0; 0][..]);
    }

    #[test]
    fn pat() {
        // One transport stream packet. Its 4-byte header is PID 0 with
        // payload_unit_start_indicator set. The payload is a PAT section listing program 1.
        // The rest of the packet is 0xFF stuffing, so the buffer is exactly one packet long.
        let mut buf = vec![0xffu8; packet::Packet::SIZE];
        let leading = hex!("474000150000B00D0001C100000001E1E02D507804");
        buf[..leading.len()].copy_from_slice(&leading);
        let mut ctx = NullDemuxContext::new();
        let mut deplex = demultiplex::Demultiplex::new(&mut ctx);
        deplex.push(&mut ctx, &buf[..]);
    }

    #[test]
    fn pat_no_existing_program() {
        let mut processor = demultiplex::PatProcessor::default();
        let section = vec![
            // common header (table_id 0x00)
            0, 0, 0,
            // table syntax header
            0x0D, 0x00, 0b00000001, 0xC1, 0x00,
            // program_number 1
            0, 1,
            // PID 101
            0, 101,
            // CRC placeholder: `section()` is called directly, bypassing the CRC check
            0, 0, 0, 0,
        ];
        let header = psi::SectionCommonHeader::new(&section[..psi::SectionCommonHeader::SIZE]);
        let table_syntax_header =
            psi::TableSyntaxHeader::new(&section[psi::SectionCommonHeader::SIZE..]);
        let mut ctx = NullDemuxContext::new();
        processor.section(&mut ctx, &header, &table_syntax_header, &section[..]);
        let mut changes = ctx.changeset.updates.into_iter();
        if let Some(demultiplex::FilterChange::Insert(pid, _)) = changes.next() {
            assert_eq!(packet::Pid::new(101), pid);
        } else {
            panic!();
        }
    }

    #[test]
    fn pat_remove_existing_program() {
        let mut ctx = NullDemuxContext::new();
        let mut processor = demultiplex::PatProcessor::default();
        {
            // First PAT: program 1 on PID 101.
            let section = vec![
                0, 0, 0,
                0x0D, 0x00, 0b00000001, 0xC1, 0x00,
                0, 1,
                0, 101,
                0, 0, 0, 0,
            ];
            let header =
                psi::SectionCommonHeader::new(&section[..psi::SectionCommonHeader::SIZE]);
            let table_syntax_header =
                psi::TableSyntaxHeader::new(&section[psi::SectionCommonHeader::SIZE..]);
            processor.section(&mut ctx, &header, &table_syntax_header, &section[..]);
        }
        ctx.changeset.updates.clear();
        {
            // Second PAT (different version byte) listing no programs at all.
            let section = vec![
                0, 0, 0,
                0x0D, 0x00, 0b00000011, 0xC1, 0x00,
                0, 0, 0, 0,
            ];
            let header =
                psi::SectionCommonHeader::new(&section[..psi::SectionCommonHeader::SIZE]);
            let table_syntax_header =
                psi::TableSyntaxHeader::new(&section[psi::SectionCommonHeader::SIZE..]);
            processor.section(&mut ctx, &header, &table_syntax_header, &section[..]);
        }
        let mut changes = ctx.changeset.updates.into_iter();
        if let Some(demultiplex::FilterChange::Remove(pid)) = changes.next() {
            assert_eq!(packet::Pid::new(101), pid);
        } else {
            panic!();
        }
    }

    /// Runs `builder` against a big-endian bit writer and returns the written bytes.
    ///
    /// # Panics
    ///
    /// Panics if `builder` returns an error.
    pub(crate) fn make_test_data<F>(builder: F) -> Vec<u8>
    where
        F: Fn(&mut BitWriter<Vec<u8>, BE>) -> Result<(), io::Error>,
    {
        let data: Vec<u8> = Vec::new();
        let mut w = BitWriter::endian(data, BigEndian);
        builder(&mut w).unwrap();
        w.into_writer()
    }

    #[test]
    fn pmt_new_stream() {
        let pid = packet::Pid::new(101);
        let program_number = 1001;
        let mut processor = demultiplex::PmtProcessor::new(pid, program_number);
        let section = make_test_data(|w| {
            // common header
            w.write(8, 0x02)?; // table_id: PMT
            w.write_bit(true)?;
            w.write_bit(false)?;
            w.write(2, 3)?;
            w.write(12, 20)?;
            // table syntax header
            w.write(16, 0)?;
            w.write(2, 3)?;
            w.write(5, 0)?;
            w.write(1, 1)?;
            w.write(8, 0)?;
            w.write(8, 0)?;
            // PMT fields: PCR PID 123, empty program info
            w.write(3, 7)?;
            w.write(13, 123)?;
            w.write(4, 15)?;
            w.write(12, 0)?;
            // one elementary stream on PID 201 with 6 bytes of descriptors
            w.write(8, 0)?;
            w.write(3, 7)?;
            w.write(13, 201)?;
            w.write(4, 15)?;
            w.write(12, 6)?;
            w.write(8, 0)?;
            w.write(8, 1)?;
            w.write(8, 0)?;
            w.write(8, 0)?;
            w.write(8, 1)?;
            w.write(8, 0)?;
            // CRC placeholder: not checked because `section()` is called directly
            w.write(32, 0)
        });
        let header = psi::SectionCommonHeader::new(&section[..psi::SectionCommonHeader::SIZE]);
        let table_syntax_header =
            psi::TableSyntaxHeader::new(&section[psi::SectionCommonHeader::SIZE..]);
        let mut ctx = NullDemuxContext::new();
        processor.section(&mut ctx, &header, &table_syntax_header, &section[..]);
        let mut changes = ctx.changeset.updates.into_iter();
        if let Some(demultiplex::FilterChange::Insert(pid, _)) = changes.next() {
            assert_eq!(packet::Pid::new(201), pid);
        } else {
            panic!();
        }
    }
}