#![forbid(unsafe_code)]
use oximedia_core::Rational;
use crate::{Packet, StreamInfo};
/// One encoded block produced by [`ParallelInterleaver`].
///
/// `payload` holds the serialized block body: the track number as an
/// EBML-style variable-length integer, the cluster-relative timecode as a
/// big-endian `i16`, one flags byte (`0x80` = keyframe), then the packet data.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EncodedBlock {
    /// Track number written into the block header.
    pub track_num: u64,
    /// Absolute timecode produced by the caller's PTS conversion closure.
    pub abs_timecode: i64,
    /// Timecode relative to the cluster, clamped into the `i16` range.
    pub rel_timecode: i16,
    /// Whether the source packet was flagged as a keyframe.
    pub is_keyframe: bool,
    /// Serialized block body (header bytes followed by the packet data).
    pub payload: Vec<u8>,
    /// Stream index copied from the source packet.
    pub stream_index: usize,
    /// Length in bytes of the packet data, excluding the block header.
    pub data_len: usize,
}
/// Tuning knobs for [`ParallelInterleaver`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct InterleaverConfig {
    /// Number of buffered packets at which `push` flushes automatically.
    pub max_pending: usize,
    /// Minimum batch size that is encoded on worker threads; smaller batches
    /// are encoded on the calling thread.
    pub parallel_threshold: usize,
}

impl Default for InterleaverConfig {
    fn default() -> Self {
        // Delegate so the default values are defined in exactly one place.
        Self::new()
    }
}

impl InterleaverConfig {
    /// Creates a configuration with `max_pending = 32` and
    /// `parallel_threshold = 4`.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            max_pending: 32,
            parallel_threshold: 4,
        }
    }

    /// Returns this configuration with `max_pending` set to `n`.
    #[must_use]
    pub const fn with_max_pending(mut self, n: usize) -> Self {
        self.max_pending = n;
        self
    }

    /// Returns this configuration with `parallel_threshold` set to `n`.
    #[must_use]
    pub const fn with_parallel_threshold(mut self, n: usize) -> Self {
        self.parallel_threshold = n;
        self
    }
}
/// A packet waiting in the interleaver buffer, together with the values
/// `push` received or computed for it.
#[derive(Clone)]
struct Pending {
// The buffered packet; its data, keyframe flag and stream index are read at encode time.
packet: Packet,
// Track number written into the encoded block header.
track_num: u64,
// Absolute timecode from the caller's conversion closure; the sort key used on flush.
abs_timecode: i64,
// Cluster timecode; the block stores `abs_timecode` relative to this value.
cluster_timecode: i64,
}
/// Buffers packets, sorts them by absolute timecode and encodes them into
/// `EncodedBlock`s, using worker threads for sufficiently large batches.
pub struct ParallelInterleaver {
// Flush and parallelism thresholds.
config: InterleaverConfig,
// Packets accepted by `push` that have not been encoded yet.
pending: Vec<Pending>,
}
impl ParallelInterleaver {
    /// Upper bound on the buffer capacity reserved up front, so a very large
    /// `max_pending` (e.g. `usize::MAX` to rely on manual flushing only) does
    /// not cause a huge allocation or a capacity-overflow panic.
    const MAX_PREALLOC: usize = 1024;

    /// Creates an interleaver using `InterleaverConfig::default()`.
    #[must_use]
    pub fn new() -> Self {
        Self::with_config(InterleaverConfig::default())
    }

    /// Creates an interleaver with an explicit configuration.
    #[must_use]
    pub fn with_config(config: InterleaverConfig) -> Self {
        // Read the capacity before `config` is moved into the struct.
        let capacity = config.max_pending.min(Self::MAX_PREALLOC);
        Self {
            config,
            pending: Vec::with_capacity(capacity),
        }
    }

    /// Returns the number of packets buffered but not yet encoded.
    #[must_use]
    pub fn pending_count(&self) -> usize {
        self.pending.len()
    }

    /// Buffers `packet`, and encodes the whole buffer once it holds
    /// `config.max_pending` packets.
    ///
    /// `to_timecode` converts the packet's PTS (in `stream_info.timebase`)
    /// into the absolute timecode used for sorting and for the block header.
    ///
    /// Returns the encoded blocks, sorted by absolute timecode, when the
    /// buffer was flushed, and an empty vector otherwise. Returned blocks are
    /// removed from the buffer, so the caller must write them out.
    pub fn push<F>(
        &mut self,
        packet: Packet,
        stream_info: &StreamInfo,
        track_num: u64,
        cluster_timecode: i64,
        to_timecode: F,
    ) -> Vec<EncodedBlock>
    where
        F: FnOnce(i64, Rational) -> i64,
    {
        let abs_timecode = to_timecode(packet.pts(), stream_info.timebase);
        self.pending.push(Pending {
            packet,
            track_num,
            abs_timecode,
            cluster_timecode,
        });
        if self.pending.len() >= self.config.max_pending {
            self.flush_inner()
        } else {
            Vec::new()
        }
    }

    /// Encodes and returns every buffered packet, sorted by absolute timecode.
    ///
    /// Returns an empty vector when nothing is buffered.
    pub fn flush(&mut self) -> Vec<EncodedBlock> {
        self.flush_inner()
    }

    /// Sorts and drains the buffer, then encodes it either inline or on
    /// worker threads depending on `config.parallel_threshold`.
    fn flush_inner(&mut self) -> Vec<EncodedBlock> {
        if self.pending.is_empty() {
            return Vec::new();
        }
        // Stable sort: packets sharing a timecode keep their arrival order, so
        // frames of one track are never reordered relative to each other.
        self.pending.sort_by_key(|p| p.abs_timecode);
        let items: Vec<Pending> = std::mem::take(&mut self.pending);
        if items.len() < self.config.parallel_threshold {
            items.into_iter().map(encode_pending).collect()
        } else {
            parallel_encode(items)
        }
    }
}
impl Default for ParallelInterleaver {
fn default() -> Self {
Self::new()
}
}
/// Serializes one buffered packet into an [`EncodedBlock`].
///
/// Payload layout: the track number as an EBML-style variable-length
/// integer, the cluster-relative timecode as a big-endian `i16`, one flags
/// byte (`0x80` for a keyframe, `0x00` otherwise), then the packet data.
///
/// The relative timecode is `abs_timecode - cluster_timecode` saturated to
/// the `i16` range. A clamped value no longer matches the packet's real
/// timecode, so the offset should stay within that range.
fn encode_pending(p: Pending) -> EncodedBlock {
    let is_keyframe = p.packet.is_keyframe();
    let data = &p.packet.data;
    let data_len = data.len();

    // `saturating_sub` avoids an overflow panic (debug) or a wrapped,
    // sign-flipped offset (release) for extreme timecodes.
    let delta = p.abs_timecode.saturating_sub(p.cluster_timecode);
    // Lossless cast: the value was clamped into the i16 range first.
    let rel = delta.clamp(i64::from(i16::MIN), i64::from(i16::MAX)) as i16;

    let track_vint = encode_vint(p.track_num);
    // Header = track VINT + 2 timecode bytes + 1 flags byte.
    let mut payload = Vec::with_capacity(track_vint.len() + 3 + data_len);
    payload.extend_from_slice(&track_vint);
    payload.extend_from_slice(&rel.to_be_bytes());
    payload.push(if is_keyframe { 0x80 } else { 0x00 });
    payload.extend_from_slice(data);

    EncodedBlock {
        track_num: p.track_num,
        abs_timecode: p.abs_timecode,
        rel_timecode: rel,
        is_keyframe,
        payload,
        stream_index: p.packet.stream_index,
        data_len,
    }
}
fn parallel_encode(items: Vec<Pending>) -> Vec<EncodedBlock> {
let n = items.len();
let mut results: Vec<Option<EncodedBlock>> = (0..n).map(|_| None).collect();
std::thread::scope(|s| {
let pairs: Vec<(&Pending, &mut Option<EncodedBlock>)> =
items.iter().zip(results.iter_mut()).collect();
let handles: Vec<_> = pairs
.into_iter()
.map(|(item, slot)| {
s.spawn(move || {
*slot = Some(encode_pending(item.clone()));
})
})
.collect();
for h in handles {
let _ = h.join();
}
});
results.into_iter().flatten().collect()
}
/// Encodes `value` as an EBML-style variable-length integer.
///
/// Values below `0x80`, `0x4000`, `0x20_0000` and `0x1000_0000` use 1, 2, 3
/// and 4 bytes respectively; anything larger uses 8 bytes. The width is
/// signalled by a single marker bit (`0x80 >> (width - 1)`) in the first
/// byte, followed by the value in big-endian order.
///
/// Note: the 8-byte form carries only the low 56 bits of `value`; higher
/// bits are discarded.
fn encode_vint(value: u64) -> Vec<u8> {
    let width: usize = match value {
        0..=0x7F => 1,
        0x80..=0x3FFF => 2,
        0x4000..=0x1F_FFFF => 3,
        0x20_0000..=0x0FFF_FFFF => 4,
        _ => 8,
    };
    let marker: u8 = 0x80 >> (width - 1);
    let big_endian = value.to_be_bytes();
    let mut encoded = big_endian[big_endian.len() - width..].to_vec();
    // Clear the marker bit and everything above it, then set the marker.
    // For widths 1-4 those bits are already zero. For width 8 this drops
    // bits 56..64 of `value`.
    encoded[0] = (encoded[0] & (marker - 1)) | marker;
    encoded
}
// Unit tests for VINT encoding, block serialization, and interleaver buffering/flushing.
#[cfg(test)]
mod tests {
use super::*;
use bytes::Bytes;
use oximedia_core::{CodecId, Rational, Timestamp};
use crate::{PacketFlags, StreamInfo};
// Builds a packet on `stream_index` carrying `size` bytes of 0xAB, with its PTS in a 1/1000 timebase.
fn make_packet(stream_index: usize, pts: i64, keyframe: bool, size: usize) -> Packet {
let flags = if keyframe {
PacketFlags::KEYFRAME
} else {
PacketFlags::empty()
};
let data = Bytes::from(vec![0xABu8; size]);
let ts = Timestamp::new(pts, Rational::new(1, 1000));
Packet::new(stream_index, data, ts, flags)
}
// Stream info with a 1/1000 timebase; `push` only reads the timebase from it.
fn make_stream(index: usize) -> StreamInfo {
StreamInfo::new(index, CodecId::Vp9, Rational::new(1, 1000))
}
// Timecode conversion that passes the PTS through unchanged.
fn identity_timecode(pts: i64, _timebase: Rational) -> i64 {
pts
}
#[test]
fn test_encode_vint_1byte() {
assert_eq!(encode_vint(1), vec![0x81]);
assert_eq!(encode_vint(0), vec![0x80]);
assert_eq!(encode_vint(127), vec![0xFF]);
}
#[test]
fn test_encode_vint_2byte() {
assert_eq!(encode_vint(128), vec![0x40, 0x80]);
}
#[test]
fn test_encode_pending_keyframe_flag() {
let pkt = make_packet(0, 1000, true, 10);
let p = Pending {
packet: pkt,
track_num: 1,
abs_timecode: 1000,
cluster_timecode: 0,
};
let block = encode_pending(p);
assert!(block.is_keyframe);
// The flags byte follows the track VINT and the 2-byte relative timecode.
let flags_offset = encode_vint(1).len() + 2;
assert_eq!(block.payload[flags_offset], 0x80);
}
#[test]
fn test_encode_pending_non_keyframe() {
let pkt = make_packet(0, 0, false, 5);
let p = Pending {
packet: pkt,
track_num: 1,
abs_timecode: 0,
cluster_timecode: 0,
};
let block = encode_pending(p);
assert!(!block.is_keyframe);
let flags_offset = encode_vint(1).len() + 2;
assert_eq!(block.payload[flags_offset], 0x00);
}
#[test]
fn test_encode_pending_rel_timecode() {
let pkt = make_packet(0, 5000, false, 4);
let p = Pending {
packet: pkt,
track_num: 2,
abs_timecode: 5000,
cluster_timecode: 4000,
};
let block = encode_pending(p);
assert_eq!(block.rel_timecode, 1000);
let vint_len = encode_vint(2).len();
let rel_bytes = &block.payload[vint_len..vint_len + 2];
// 1000 == 0x03E8, stored big-endian.
assert_eq!(rel_bytes, &[0x03, 0xE8]);
}
#[test]
fn test_encode_pending_data_appended() {
let pkt = make_packet(0, 0, true, 8);
let p = Pending {
packet: pkt,
track_num: 1,
abs_timecode: 0,
cluster_timecode: 0,
};
let block = encode_pending(p);
assert_eq!(block.data_len, 8);
// 1-byte track VINT + 2 timecode bytes + 1 flags byte + 8 data bytes.
assert_eq!(block.payload.len(), 1 + 2 + 1 + 8);
}
#[test]
fn test_interleaver_buffers_below_threshold() {
let config = InterleaverConfig::new().with_max_pending(8);
let mut il = ParallelInterleaver::with_config(config);
let stream = make_stream(0);
let pkt = make_packet(0, 0, true, 4);
let result = il.push(pkt, &stream, 1, 0, identity_timecode);
assert!(result.is_empty(), "should buffer, not encode yet");
assert_eq!(il.pending_count(), 1);
}
#[test]
fn test_interleaver_encodes_when_full() {
let config = InterleaverConfig::new()
.with_max_pending(4)
.with_parallel_threshold(2);
let mut il = ParallelInterleaver::with_config(config);
let stream = make_stream(0);
let mut last_result = Vec::new();
// The 4th push reaches max_pending and returns all four blocks.
for i in 0..4u64 {
let pkt = make_packet(0, i as i64 * 100, i == 0, 4);
last_result = il.push(pkt, &stream, 1, 0, identity_timecode);
}
assert_eq!(last_result.len(), 4);
assert_eq!(il.pending_count(), 0);
}
#[test]
fn test_interleaver_flush_drains_buffer() {
let mut il = ParallelInterleaver::new();
let stream = make_stream(0);
for i in 0..3u64 {
let pkt = make_packet(0, i as i64 * 100, i == 0, 4);
il.push(pkt, &stream, 1, 0, identity_timecode);
}
assert_eq!(il.pending_count(), 3);
let blocks = il.flush();
assert_eq!(blocks.len(), 3);
assert_eq!(il.pending_count(), 0);
}
#[test]
fn test_interleaver_flush_empty_returns_empty() {
let mut il = ParallelInterleaver::new();
let blocks = il.flush();
assert!(blocks.is_empty());
}
#[test]
fn test_interleaver_sorts_by_pts() {
let config = InterleaverConfig::new()
.with_max_pending(3)
.with_parallel_threshold(2);
let mut il = ParallelInterleaver::with_config(config);
let stream = make_stream(0);
// Pushed out of order; the flush must return them sorted by timecode.
let pts_values: &[i64] = &[300, 100, 200];
let mut result = Vec::new();
for &pts in pts_values {
let pkt = make_packet(0, pts, false, 4);
result = il.push(pkt, &stream, 1, 0, identity_timecode);
}
let timecodes: Vec<i64> = result.iter().map(|b| b.abs_timecode).collect();
assert_eq!(timecodes, vec![100, 200, 300]);
}
#[test]
fn test_interleaver_parallel_matches_sequential() {
let n = 20usize;
let stream = make_stream(0);
// max_pending n + 1 keeps all packets buffered until the explicit flush.
// Threshold n + 1 forces the sequential path; threshold 1 forces the parallel path.
let mut seq_il = ParallelInterleaver::with_config(
InterleaverConfig::new()
.with_max_pending(n + 1)
.with_parallel_threshold(n + 1), );
let mut par_il = ParallelInterleaver::with_config(
InterleaverConfig::new()
.with_max_pending(n + 1)
.with_parallel_threshold(1), );
let packets: Vec<Packet> = (0..n)
.map(|i| make_packet(0, i as i64 * 50, i == 0, 8))
.collect();
for pkt in &packets {
seq_il.push(pkt.clone(), &stream, 1, 0, identity_timecode);
par_il.push(pkt.clone(), &stream, 1, 0, identity_timecode);
}
let seq_blocks = seq_il.flush();
let par_blocks = par_il.flush();
assert_eq!(seq_blocks.len(), par_blocks.len());
for (s, p) in seq_blocks.iter().zip(par_blocks.iter()) {
assert_eq!(s.payload, p.payload, "payloads differ at timecode {}", s.abs_timecode);
assert_eq!(s.is_keyframe, p.is_keyframe);
assert_eq!(s.rel_timecode, p.rel_timecode);
}
}
#[test]
fn test_interleaver_multi_track() {
let config = InterleaverConfig::new()
.with_max_pending(4)
.with_parallel_threshold(2);
let mut il = ParallelInterleaver::with_config(config);
let video = make_stream(0);
let audio = make_stream(1);
// Video on track 1 and audio on track 2, interleaved by timecode; the 4th push flushes.
il.push(make_packet(0, 0, true, 100), &video, 1, 0, identity_timecode);
il.push(make_packet(1, 10, false, 20), &audio, 2, 0, identity_timecode);
il.push(make_packet(0, 33, false, 90), &video, 1, 0, identity_timecode);
let blocks = il.push(make_packet(1, 43, false, 20), &audio, 2, 0, identity_timecode);
assert_eq!(blocks.len(), 4);
assert_eq!(blocks[0].abs_timecode, 0);
}
#[test]
fn test_config_defaults() {
let c = InterleaverConfig::default();
assert_eq!(c.max_pending, 32);
assert_eq!(c.parallel_threshold, 4);
}
#[test]
fn test_config_builder() {
let c = InterleaverConfig::new()
.with_max_pending(16)
.with_parallel_threshold(8);
assert_eq!(c.max_pending, 16);
assert_eq!(c.parallel_threshold, 8);
}
}