#![allow(dead_code)]
#![forbid(unsafe_code)]
use std::cmp::Ordering;
use std::collections::BinaryHeap;
/// One encoded packet belonging to a single elementary stream.
#[derive(Debug, Clone)]
pub struct StreamPacket {
    /// Index of the stream this packet belongs to.
    pub stream_index: usize,
    /// Decode timestamp; the interleaver orders its output by this value.
    pub dts: i64,
    /// Presentation timestamp; [`StreamPacket::new`] initialises it to `dts`.
    pub pts: i64,
    /// Encoded payload bytes.
    pub data: Vec<u8>,
    /// Whether the packet is flagged as a keyframe.
    pub is_keyframe: bool,
    /// Packet duration; [`StreamPacket::new`] initialises it to 0.
    pub duration: i64,
    /// Arrival order stamped by the interleaver when the packet is pushed;
    /// breaks ties between packets with equal `dts`.
    insertion_seq: u64,
}

impl StreamPacket {
    /// Creates a packet whose `pts` equals `dts`, with a zero duration and a
    /// zero (not yet assigned) insertion sequence.
    #[must_use]
    pub fn new(stream_index: usize, dts: i64, data: Vec<u8>, is_keyframe: bool) -> Self {
        let pts = dts;
        Self {
            stream_index,
            dts,
            pts,
            data,
            is_keyframe,
            duration: 0,
            insertion_seq: 0,
        }
    }

    /// Returns the packet with its presentation timestamp replaced by `pts`.
    #[must_use]
    pub fn with_pts(self, pts: i64) -> Self {
        Self { pts, ..self }
    }

    /// Returns the packet with its duration replaced by `duration`.
    #[must_use]
    pub fn with_duration(self, duration: i64) -> Self {
        Self { duration, ..self }
    }

    /// Payload length in bytes.
    #[must_use]
    pub fn size(&self) -> usize {
        let Self { data, .. } = self;
        data.len()
    }
}
/// Wrapper that turns `BinaryHeap` (a max-heap) into a min-heap over packets:
/// the entry with the lowest `dts` is popped first, and among equal `dts` the
/// one with the lowest `insertion_seq` wins.
#[derive(Debug)]
struct HeapEntry(StreamPacket);

impl HeapEntry {
    /// Natural ordering key; the `Ord` impl compares it in reverse.
    fn key(&self) -> (i64, u64) {
        (self.0.dts, self.0.insertion_seq)
    }
}

impl PartialEq for HeapEntry {
    fn eq(&self, other: &Self) -> bool {
        self.key() == other.key()
    }
}

impl Eq for HeapEntry {}

impl PartialOrd for HeapEntry {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for HeapEntry {
    fn cmp(&self, other: &Self) -> Ordering {
        // Reversed so that the smallest (dts, insertion_seq) is the heap's "maximum".
        self.key().cmp(&other.key()).reverse()
    }
}
/// Per-stream packet queue used by [`ParallelInterleaver`].
///
/// Packets are kept in ascending `dts` order at all times (packets with equal
/// `dts` keep their push order). The front therefore always holds the lowest
/// buffered `dts` and the back the highest.
#[derive(Debug)]
struct StreamBuffer {
    /// Index of the stream this buffer belongs to.
    stream_index: usize,
    /// Buffered packets, sorted by `dts`.
    packets: Vec<StreamPacket>,
    /// Sum of `data.len()` over `packets`.
    buffered_bytes: usize,
    /// Total number of packets drained from this buffer so far.
    flushed_count: u64,
    /// `dts` of the last packet handed out by a drain, if any.
    last_flushed_dts: Option<i64>,
}

impl StreamBuffer {
    /// Creates an empty buffer for `stream_index`.
    fn new(stream_index: usize) -> Self {
        Self {
            stream_index,
            packets: Vec::new(),
            buffered_bytes: 0,
            flushed_count: 0,
            last_flushed_dts: None,
        }
    }

    /// Inserts `packet` at its sorted position.
    ///
    /// The insertion point lies after every packet with an equal or lower
    /// `dts`. This keeps ties in arrival order and turns the common in-order
    /// case into a plain append. Maintaining the order here means
    /// `min_dts()` and `packets.last()` always report the true lowest and
    /// highest buffered `dts`, even between flushes. Previously they reported
    /// the first- and last-pushed packets, which produced negative spans for
    /// out-of-order input.
    fn push(&mut self, packet: StreamPacket) {
        self.buffered_bytes += packet.data.len();
        let pos = self.packets.partition_point(|p| p.dts <= packet.dts);
        self.packets.insert(pos, packet);
    }

    /// Stable sort by `dts`. `push` already maintains this order, so on
    /// already-sorted data this is a cheap pass kept for existing callers.
    fn sort_by_dts(&mut self) {
        self.packets.sort_by_key(|p| p.dts);
    }

    /// Removes and returns every buffered packet in `dts` order.
    ///
    /// Resets the byte count, adds the drained packets to `flushed_count`, and
    /// records the highest drained `dts` in `last_flushed_dts`.
    fn drain_all(&mut self) -> Vec<StreamPacket> {
        self.buffered_bytes = 0;
        self.flushed_count += self.packets.len() as u64;
        if let Some(last) = self.packets.last() {
            self.last_flushed_dts = Some(last.dts);
        }
        std::mem::take(&mut self.packets)
    }

    /// Removes and returns, in `dts` order, every packet with `dts <= max_dts`.
    ///
    /// Updates the byte count, `flushed_count`, and `last_flushed_dts` (only
    /// when at least one packet was drained). The binary search relies on the
    /// sorted invariant maintained by `push`.
    fn drain_up_to_dts(&mut self, max_dts: i64) -> Vec<StreamPacket> {
        let split_idx = self.packets.partition_point(|p| p.dts <= max_dts);
        let drained: Vec<StreamPacket> = self.packets.drain(..split_idx).collect();
        let drained_bytes: usize = drained.iter().map(|p| p.data.len()).sum();
        self.buffered_bytes -= drained_bytes;
        self.flushed_count += drained.len() as u64;
        if let Some(last) = drained.last() {
            self.last_flushed_dts = Some(last.dts);
        }
        drained
    }

    /// Number of buffered packets.
    fn len(&self) -> usize {
        self.packets.len()
    }

    /// Whether no packets are buffered.
    fn is_empty(&self) -> bool {
        self.packets.is_empty()
    }

    /// Lowest buffered `dts` (the front of the sorted buffer), if any.
    fn min_dts(&self) -> Option<i64> {
        self.packets.first().map(|p| p.dts)
    }
}
/// Policy that decides when [`ParallelInterleaver::should_flush`] reports
/// that buffered packets should be drained.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FlushStrategy {
    /// Flush once any single stream buffers at least `threshold` packets.
    PacketCount { threshold: usize },
    /// Flush once any single stream buffers at least `threshold` payload bytes.
    ByteCount { threshold: usize },
    /// Flush once the spread of buffered `dts` values is strictly greater
    /// than `max_span`.
    DtsSpan { max_span: i64 },
    /// Never request a flush; packets leave only through explicit flush calls.
    Manual,
}

impl FlushStrategy {
    /// Per-stream packet threshold used by the default strategy.
    const DEFAULT_PACKET_THRESHOLD: usize = 64;
}

impl Default for FlushStrategy {
    /// Defaults to [`FlushStrategy::PacketCount`] with a threshold of 64 packets.
    fn default() -> Self {
        Self::PacketCount {
            threshold: Self::DEFAULT_PACKET_THRESHOLD,
        }
    }
}
/// Running counters collected by [`ParallelInterleaver`]; all start at zero.
#[derive(Debug, Clone, Copy, Default)]
pub struct InterleaveStats {
    /// Packets accepted by `push` (packets for unknown streams are excluded).
    pub total_packets_in: u64,
    /// Packets returned by the flush methods.
    pub total_packets_out: u64,
    /// Payload bytes of accepted packets.
    pub total_bytes_in: u64,
    /// Payload bytes of returned packets.
    pub total_bytes_out: u64,
    /// Number of flush operations that ran (`flush_ready` does not count when
    /// nothing was buffered).
    pub flush_count: u64,
    /// Largest total packet count held across all stream buffers, sampled
    /// after each accepted push.
    pub peak_buffer_packets: usize,
    /// Number of packets recorded as needing in-stream reordering.
    /// NOTE(review): confirm where the interleaver maintains this counter.
    pub reorder_corrections: u64,
}
#[derive(Debug)]
pub struct ParallelInterleaver {
buffers: Vec<StreamBuffer>,
strategy: FlushStrategy,
insertion_counter: u64,
stats: InterleaveStats,
}
impl ParallelInterleaver {
#[must_use]
pub fn new(stream_count: usize) -> Self {
let buffers = (0..stream_count).map(StreamBuffer::new).collect();
Self {
buffers,
strategy: FlushStrategy::default(),
insertion_counter: 0,
stats: InterleaveStats::default(),
}
}
#[must_use]
pub fn with_strategy(stream_count: usize, strategy: FlushStrategy) -> Self {
let mut interleaver = Self::new(stream_count);
interleaver.strategy = strategy;
interleaver
}
pub fn push(&mut self, mut packet: StreamPacket) {
let idx = packet.stream_index;
if idx >= self.buffers.len() {
return;
}
packet.insertion_seq = self.insertion_counter;
self.insertion_counter += 1;
self.stats.total_packets_in += 1;
self.stats.total_bytes_in += packet.data.len() as u64;
self.buffers[idx].push(packet);
let total_buffered: usize = self.buffers.iter().map(|b| b.len()).sum();
if total_buffered > self.stats.peak_buffer_packets {
self.stats.peak_buffer_packets = total_buffered;
}
}
pub fn push_and_maybe_flush(&mut self, packet: StreamPacket) -> Vec<StreamPacket> {
self.push(packet);
if self.should_flush() {
self.flush_ready()
} else {
Vec::new()
}
}
#[must_use]
pub fn should_flush(&self) -> bool {
match self.strategy {
FlushStrategy::PacketCount { threshold } => {
self.buffers.iter().any(|b| b.len() >= threshold)
}
FlushStrategy::ByteCount { threshold } => {
self.buffers.iter().any(|b| b.buffered_bytes >= threshold)
}
FlushStrategy::DtsSpan { max_span } => {
let min = self
.buffers
.iter()
.filter_map(|b| b.min_dts())
.min();
let max = self
.buffers
.iter()
.filter_map(|b| b.packets.last().map(|p| p.dts))
.max();
if let (Some(lo), Some(hi)) = (min, max) {
(hi - lo) > max_span
} else {
false
}
}
FlushStrategy::Manual => false,
}
}
pub fn flush_ready(&mut self) -> Vec<StreamPacket> {
let safe_dts = self
.buffers
.iter()
.filter(|b| !b.is_empty())
.filter_map(|b| b.packets.last().map(|p| p.dts))
.min();
let Some(safe_dts) = safe_dts else {
return Vec::new();
};
for buf in &mut self.buffers {
buf.sort_by_dts();
}
let mut heap: BinaryHeap<HeapEntry> = BinaryHeap::new();
for buf in &mut self.buffers {
for pkt in buf.drain_up_to_dts(safe_dts) {
heap.push(HeapEntry(pkt));
}
}
let mut output = Vec::with_capacity(heap.len());
while let Some(entry) = heap.pop() {
self.stats.total_packets_out += 1;
self.stats.total_bytes_out += entry.0.data.len() as u64;
output.push(entry.0);
}
self.stats.flush_count += 1;
output
}
pub fn flush_all(&mut self) -> Vec<StreamPacket> {
for buf in &mut self.buffers {
buf.sort_by_dts();
}
let mut heap: BinaryHeap<HeapEntry> = BinaryHeap::new();
for buf in &mut self.buffers {
for pkt in buf.drain_all() {
heap.push(HeapEntry(pkt));
}
}
let mut output = Vec::with_capacity(heap.len());
while let Some(entry) = heap.pop() {
self.stats.total_packets_out += 1;
self.stats.total_bytes_out += entry.0.data.len() as u64;
output.push(entry.0);
}
self.stats.flush_count += 1;
output
}
#[must_use]
pub fn stream_count(&self) -> usize {
self.buffers.len()
}
#[must_use]
pub fn total_buffered(&self) -> usize {
self.buffers.iter().map(|b| b.len()).sum()
}
#[must_use]
pub fn total_buffered_bytes(&self) -> usize {
self.buffers.iter().map(|b| b.buffered_bytes).sum()
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.buffers.iter().all(|b| b.is_empty())
}
#[must_use]
pub fn stream_buffered(&self, stream_index: usize) -> usize {
self.buffers
.get(stream_index)
.map_or(0, |b| b.len())
}
#[must_use]
pub fn stats(&self) -> &InterleaveStats {
&self.stats
}
pub fn reset_stats(&mut self) {
self.stats = InterleaveStats::default();
}
#[must_use]
pub fn dts_span(&self) -> i64 {
let min = self
.buffers
.iter()
.filter_map(|b| b.min_dts())
.min();
let max = self
.buffers
.iter()
.filter_map(|b| b.packets.last().map(|p| p.dts))
.max();
match (min, max) {
(Some(lo), Some(hi)) => hi - lo,
_ => 0,
}
}
#[must_use]
pub fn per_stream_sizes(&self) -> Vec<(usize, usize)> {
self.buffers
.iter()
.map(|b| (b.stream_index, b.len()))
.collect()
}
}
#[cfg(test)]
mod tests {
    //! Unit tests covering cross-stream ordering, partial and full flushes,
    //! every flush strategy, and statistics bookkeeping.

    use super::*;

    /// Builds a 100-byte packet on `stream` at `dts`; only `dts == 0` is
    /// flagged as a keyframe.
    fn pkt(stream: usize, dts: i64) -> StreamPacket {
        StreamPacket::new(stream, dts, vec![0u8; 100], dts == 0)
    }

    #[test]
    fn test_basic_two_stream_interleave() {
        let mut il = ParallelInterleaver::new(2);
        il.push(pkt(0, 0));
        il.push(pkt(1, 10));
        il.push(pkt(0, 40));
        il.push(pkt(1, 30));
        let output = il.flush_all();
        assert_eq!(output.len(), 4);
        assert_eq!(output[0].dts, 0);
        assert_eq!(output[1].dts, 10);
        assert_eq!(output[2].dts, 30);
        assert_eq!(output[3].dts, 40);
    }

    #[test]
    fn test_three_stream_interleave() {
        let mut il = ParallelInterleaver::new(3);
        il.push(pkt(0, 100));
        il.push(pkt(1, 50));
        il.push(pkt(2, 75));
        il.push(pkt(0, 200));
        il.push(pkt(1, 150));
        il.push(pkt(2, 175));
        let output = il.flush_all();
        assert_eq!(output.len(), 6);
        for w in output.windows(2) {
            assert!(
                w[0].dts <= w[1].dts,
                "out of order: {} > {}",
                w[0].dts,
                w[1].dts
            );
        }
    }

    #[test]
    fn test_empty_interleaver() {
        let mut il = ParallelInterleaver::new(2);
        assert!(il.is_empty());
        assert_eq!(il.total_buffered(), 0);
        let output = il.flush_all();
        assert!(output.is_empty());
    }

    #[test]
    fn test_single_stream() {
        let mut il = ParallelInterleaver::new(1);
        il.push(pkt(0, 300));
        il.push(pkt(0, 100));
        il.push(pkt(0, 200));
        let output = il.flush_all();
        assert_eq!(output.len(), 3);
        assert_eq!(output[0].dts, 100);
        assert_eq!(output[1].dts, 200);
        assert_eq!(output[2].dts, 300);
    }

    #[test]
    fn test_out_of_range_stream_dropped() {
        let mut il = ParallelInterleaver::new(2);
        il.push(pkt(0, 0));
        // Stream 5 does not exist, so this packet is discarded.
        il.push(pkt(5, 10));
        il.push(pkt(1, 20));
        // The dropped packet must not be counted either.
        assert_eq!(il.stats().total_packets_in, 2);
        let output = il.flush_all();
        assert_eq!(output.len(), 2);
    }

    #[test]
    fn test_flush_ready_partial() {
        let mut il = ParallelInterleaver::with_strategy(2, FlushStrategy::Manual);
        il.push(pkt(0, 0));
        il.push(pkt(0, 40));
        il.push(pkt(0, 80));
        il.push(pkt(1, 10));
        il.push(pkt(1, 30));
        // Stream 1 has only reached dts 30, so nothing later can be emitted yet.
        let partial = il.flush_ready();
        assert_eq!(partial.len(), 3);
        assert_eq!(partial[0].dts, 0);
        assert_eq!(partial[1].dts, 10);
        assert_eq!(partial[2].dts, 30);
        assert_eq!(il.total_buffered(), 2);
    }

    #[test]
    fn test_stats_tracking() {
        let mut il = ParallelInterleaver::new(2);
        il.push(pkt(0, 0));
        il.push(pkt(1, 10));
        il.push(pkt(0, 20));
        assert_eq!(il.stats().total_packets_in, 3);
        assert_eq!(il.stats().total_bytes_in, 300);
        let output = il.flush_all();
        assert_eq!(output.len(), 3);
        assert_eq!(il.stats().total_packets_out, 3);
        assert_eq!(il.stats().flush_count, 1);
    }

    #[test]
    fn test_packet_count_flush_strategy() {
        let mut il = ParallelInterleaver::with_strategy(
            2,
            FlushStrategy::PacketCount { threshold: 3 },
        );
        il.push(pkt(0, 0));
        il.push(pkt(0, 10));
        // Two packets are still below the threshold of three.
        assert!(!il.should_flush());
        il.push(pkt(0, 20));
        assert!(il.should_flush());
    }

    #[test]
    fn test_byte_count_flush_strategy() {
        let mut il = ParallelInterleaver::with_strategy(
            2,
            FlushStrategy::ByteCount { threshold: 250 },
        );
        // 2 x 100 bytes = 200 < 250.
        il.push(pkt(0, 0));
        il.push(pkt(0, 10));
        assert!(!il.should_flush());
        // 3 x 100 bytes = 300 >= 250.
        il.push(pkt(0, 20));
        assert!(il.should_flush());
    }

    #[test]
    fn test_dts_span_flush_strategy() {
        let mut il = ParallelInterleaver::with_strategy(
            2,
            FlushStrategy::DtsSpan { max_span: 50 },
        );
        il.push(pkt(0, 0));
        il.push(pkt(1, 10));
        // A span of 10 does not exceed 50.
        assert!(!il.should_flush());
        il.push(pkt(0, 60));
        // A span of 60 exceeds 50.
        assert!(il.should_flush());
    }

    #[test]
    fn test_push_and_maybe_flush() {
        let mut il = ParallelInterleaver::with_strategy(
            1,
            FlushStrategy::PacketCount { threshold: 2 },
        );
        let out1 = il.push_and_maybe_flush(pkt(0, 0));
        assert!(out1.is_empty());
        let out2 = il.push_and_maybe_flush(pkt(0, 10));
        assert!(!out2.is_empty());
    }

    #[test]
    fn test_per_stream_sizes() {
        let mut il = ParallelInterleaver::new(3);
        il.push(pkt(0, 0));
        il.push(pkt(0, 10));
        il.push(pkt(1, 5));
        let sizes = il.per_stream_sizes();
        assert_eq!(sizes.len(), 3);
        assert_eq!(sizes[0], (0, 2));
        assert_eq!(sizes[1], (1, 1));
        assert_eq!(sizes[2], (2, 0));
    }

    #[test]
    fn test_dts_span() {
        let mut il = ParallelInterleaver::new(2);
        assert_eq!(il.dts_span(), 0);
        il.push(pkt(0, 100));
        il.push(pkt(1, 300));
        assert_eq!(il.dts_span(), 200);
    }

    #[test]
    fn test_stream_packet_builder() {
        let pkt = StreamPacket::new(0, 100, vec![1, 2, 3], true)
            .with_pts(200)
            .with_duration(50);
        assert_eq!(pkt.dts, 100);
        assert_eq!(pkt.pts, 200);
        assert_eq!(pkt.duration, 50);
        assert_eq!(pkt.size(), 3);
        assert!(pkt.is_keyframe);
    }

    #[test]
    fn test_stable_ordering_equal_dts() {
        let mut il = ParallelInterleaver::new(2);
        il.push(StreamPacket::new(0, 100, vec![1], false));
        il.push(StreamPacket::new(1, 100, vec![2], false));
        il.push(StreamPacket::new(0, 100, vec![3], false));
        let output = il.flush_all();
        assert_eq!(output.len(), 3);
        // Equal dts: output follows push order across streams.
        assert_eq!(output[0].data, vec![1]);
        assert_eq!(output[1].data, vec![2]);
        assert_eq!(output[2].data, vec![3]);
    }

    #[test]
    fn test_reset_stats() {
        let mut il = ParallelInterleaver::new(1);
        il.push(pkt(0, 0));
        il.flush_all();
        assert_eq!(il.stats().total_packets_in, 1);
        il.reset_stats();
        assert_eq!(il.stats().total_packets_in, 0);
        assert_eq!(il.stats().total_packets_out, 0);
    }

    #[test]
    fn test_total_buffered_bytes() {
        let mut il = ParallelInterleaver::new(2);
        il.push(StreamPacket::new(0, 0, vec![0; 50], true));
        il.push(StreamPacket::new(1, 10, vec![0; 75], true));
        assert_eq!(il.total_buffered_bytes(), 125);
    }

    #[test]
    fn test_flush_ready_on_empty_is_noop() {
        let mut il = ParallelInterleaver::new(3);
        assert!(il.flush_ready().is_empty());
        // Nothing was buffered, so no flush is recorded.
        assert_eq!(il.stats().flush_count, 0);
    }

    #[test]
    fn test_manual_strategy_never_requests_flush() {
        let mut il = ParallelInterleaver::with_strategy(2, FlushStrategy::Manual);
        for step in 0..10 {
            il.push(pkt(0, step * 1000));
            il.push(pkt(1, step * 1000));
        }
        assert!(!il.should_flush());
        assert!(il.push_and_maybe_flush(pkt(0, 20_000)).is_empty());
        assert_eq!(il.total_buffered(), 21);
    }

    #[test]
    fn test_flush_all_leaves_interleaver_empty() {
        let mut il = ParallelInterleaver::new(2);
        il.push(pkt(0, 0));
        il.push(pkt(1, 5));
        let output = il.flush_all();
        assert_eq!(output.len(), 2);
        assert!(il.is_empty());
        assert_eq!(il.total_buffered_bytes(), 0);
        assert_eq!(il.stats().total_bytes_out, 200);
        assert_eq!(il.dts_span(), 0);
    }

    #[test]
    fn test_stream_buffered_out_of_range_is_zero() {
        let mut il = ParallelInterleaver::new(2);
        il.push(pkt(1, 0));
        assert_eq!(il.stream_count(), 2);
        assert_eq!(il.stream_buffered(1), 1);
        assert_eq!(il.stream_buffered(7), 0);
    }
}