#![allow(clippy::cast_lossless)]
#![allow(clippy::cast_precision_loss)]
#![allow(clippy::cast_possible_truncation)]
#![allow(clippy::cast_sign_loss)]
use std::cmp::Reverse;
use std::collections::BinaryHeap;
/// Per-packet boolean flags; every flag defaults to `false`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct PacketFlags {
    /// Set when the packet is marked as a keyframe.
    pub keyframe: bool,
    /// Set when the packet is marked as corrupt.
    pub corrupt: bool,
    /// Set when the packet is marked for discarding.
    pub discard: bool,
}

/// One encoded packet together with its timing information.
///
/// `pts`, `dts` and `duration` are counted in ticks of `time_base`, where one
/// tick lasts `time_base.0 / time_base.1` seconds.
#[derive(Debug, Clone)]
pub struct CodecPacket {
    /// Presentation timestamp, in `time_base` ticks.
    pub pts: u64,
    /// Decode timestamp, in `time_base` ticks.
    pub dts: u64,
    /// Packet duration, in `time_base` ticks.
    pub duration: u32,
    /// Tick length as a `(numerator, denominator)` fraction of a second.
    pub time_base: (u32, u32),
    /// Encoded payload bytes.
    pub data: Vec<u8>,
    /// Per-packet flags.
    pub flags: PacketFlags,
    /// Index of the stream this packet belongs to.
    pub stream_index: u32,
}

impl CodecPacket {
    /// Presentation timestamp converted to seconds.
    ///
    /// Returns `0.0` when the time-base denominator is zero.
    #[must_use]
    pub fn pts_secs(&self) -> f64 {
        self.ticks_to_secs(self.pts)
    }

    /// Decode timestamp converted to seconds.
    ///
    /// Returns `0.0` when the time-base denominator is zero.
    #[must_use]
    pub fn dts_secs(&self) -> f64 {
        self.ticks_to_secs(self.dts)
    }

    /// Returns a copy of this packet with `pts`, `dts` and `duration` rescaled
    /// from the current time base to `new_time_base`.
    ///
    /// Values are rounded to the nearest tick, with halves rounding up. A result
    /// too large for its field saturates at `u64::MAX` (timestamps) or
    /// `u32::MAX` (duration) instead of wrapping.
    ///
    /// If the current denominator or the new numerator is zero, the conversion
    /// is undefined. In that case the tick values are copied unchanged, while
    /// `time_base` is still set to `new_time_base`.
    ///
    /// The payload is cloned.
    #[must_use]
    pub fn rebase(&self, new_time_base: (u32, u32)) -> Self {
        let old_time_base = self.time_base;
        let rescale_ts = |ticks: u64| -> u64 {
            Self::rescale_ticks(ticks, old_time_base, new_time_base)
                .map_or(ticks, |r| r.min(u128::from(u64::MAX)) as u64)
        };
        let duration =
            Self::rescale_ticks(u64::from(self.duration), old_time_base, new_time_base)
                .map_or(self.duration, |r| r.min(u128::from(u32::MAX)) as u32);
        Self {
            pts: rescale_ts(self.pts),
            dts: rescale_ts(self.dts),
            duration,
            time_base: new_time_base,
            data: self.data.clone(),
            flags: self.flags.clone(),
            stream_index: self.stream_index,
        }
    }

    /// Converts `ticks` in this packet's time base to seconds.
    ///
    /// Returns `0.0` for a zero denominator.
    fn ticks_to_secs(&self, ticks: u64) -> f64 {
        let (num, den) = self.time_base;
        if den == 0 {
            return 0.0;
        }
        ticks as f64 * f64::from(num) / f64::from(den)
    }

    /// Rescales `ticks` from time base `src` to time base `dst`, rounding to
    /// nearest with halves rounding up.
    ///
    /// Returns `None` when `src` has a zero denominator or `dst` has a zero
    /// numerator, since the ratio is undefined then.
    fn rescale_ticks(ticks: u64, src: (u32, u32), dst: (u32, u32)) -> Option<u128> {
        let (src_num, src_den) = src;
        let (dst_num, dst_den) = dst;
        if src_den == 0 || dst_num == 0 {
            return None;
        }
        // The largest possible numerator is u64::MAX * u32::MAX * u32::MAX.
        // That is below 2^128 - 2^96, so adding half the denominator (< 2^63)
        // cannot overflow u128.
        let numerator = u128::from(ticks) * u128::from(src_num) * u128::from(dst_den);
        let denominator = u128::from(src_den) * u128::from(dst_num);
        Some((numerator + denominator / 2) / denominator)
    }
}
/// Produces `CodecPacket`s for a single stream with advancing timestamps.
///
/// `pts` and `dts` both start at zero. They advance together, saturating, by
/// the duration of each packet emitted.
#[derive(Debug, Clone)]
pub struct PacketBuilder {
    stream_index: u32,
    time_base: (u32, u32),
    pts_counter: u64,
    dts_counter: u64,
    frame_duration: u32,
}

impl PacketBuilder {
    /// Creates a builder for `stream_index` that stamps packets in `time_base` ticks.
    ///
    /// The per-frame duration is `den / (fps * num)` ticks, rounded and clamped
    /// to at least 1. A zero time-base numerator or a non-positive `fps` gives a
    /// duration of 1.
    #[must_use]
    pub fn new(stream_index: u32, time_base: (u32, u32), fps: f32) -> Self {
        let (num, den) = time_base;
        let frame_duration = if num == 0 || fps <= 0.0 {
            1
        } else {
            // Float-to-int `as` saturates (NaN becomes 0), so `.max(1)` keeps
            // the result usable for NaN or extreme `fps` values.
            ((f64::from(den) / (f64::from(fps) * f64::from(num))).round() as u32).max(1)
        };
        Self {
            stream_index,
            time_base,
            pts_counter: 0,
            dts_counter: 0,
            frame_duration,
        }
    }

    /// Emits a video packet lasting one frame duration, flagged as a keyframe
    /// when `keyframe` is true.
    pub fn build_video_frame(&mut self, data: Vec<u8>, keyframe: bool) -> CodecPacket {
        let duration = self.frame_duration;
        self.emit(data, duration, keyframe)
    }

    /// Emits a non-keyframe audio packet lasting `samples` ticks.
    ///
    /// When `samples` is 0, the frame duration is used instead.
    ///
    /// NOTE(review): `samples` is used directly as a tick count, which
    /// presumably assumes a `1/sample_rate` time base — confirm with callers.
    pub fn build_audio_frame(&mut self, data: Vec<u8>, samples: u32) -> CodecPacket {
        let duration = if samples > 0 {
            samples
        } else {
            self.frame_duration
        };
        self.emit(data, duration, false)
    }

    /// Timestamp that the next emitted packet will carry.
    #[must_use]
    pub fn next_pts(&self) -> u64 {
        self.pts_counter
    }

    /// Duration, in ticks, of one video frame.
    #[must_use]
    pub fn frame_duration(&self) -> u32 {
        self.frame_duration
    }

    /// Stamps `data` with the current clocks, then advances both clocks by
    /// `duration`, saturating at `u64::MAX`.
    fn emit(&mut self, data: Vec<u8>, duration: u32, keyframe: bool) -> CodecPacket {
        let pkt = CodecPacket {
            pts: self.pts_counter,
            dts: self.dts_counter,
            duration,
            time_base: self.time_base,
            data,
            flags: PacketFlags {
                keyframe,
                corrupt: false,
                discard: false,
            },
            stream_index: self.stream_index,
        };
        let step = u64::from(duration);
        self.pts_counter = self.pts_counter.saturating_add(step);
        self.dts_counter = self.dts_counter.saturating_add(step);
        pkt
    }
}
/// Heap element pairing a `(.0, .1)` sort key with the packet it came from.
///
/// Only the two key fields take part in equality and ordering; the packet
/// payload is ignored. Comparison is lexicographic: `.0` first, then `.1`.
#[derive(Debug)]
struct HeapEntry(u64, u64, CodecPacket);

impl PartialEq for HeapEntry {
    fn eq(&self, other: &Self) -> bool {
        // Defined through `cmp` so equality can never disagree with ordering.
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for HeapEntry {}

impl PartialOrd for HeapEntry {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for HeapEntry {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.0
            .cmp(&other.0)
            .then_with(|| self.1.cmp(&other.1))
    }
}
/// Reorders packets into ascending `(pts, dts)` order using a look-ahead window.
///
/// Packets are kept in a min-heap keyed on `(pts, dts)`. Once at least
/// `max_buffer` packets are held, `pop_ready` releases the one with the
/// smallest key. `drain` flushes everything in key order.
///
/// `push` never evicts, so the buffer stays bounded only if callers pop after
/// pushing. Packets with identical `(pts, dts)` are not guaranteed to come out
/// in insertion order.
#[derive(Debug)]
pub struct PacketReorderer {
    buffer: BinaryHeap<Reverse<HeapEntry>>,
    max_buffer: usize,
}

impl PacketReorderer {
    /// Creates an empty reorderer whose window is `max_buffer` packets.
    ///
    /// A `max_buffer` of 0 is treated as 1.
    #[must_use]
    pub fn new(max_buffer: usize) -> Self {
        // Cap on the up-front allocation only; the heap still grows on demand.
        // Without it, a huge window (e.g. `usize::MAX` meaning "drain only")
        // overflows `max_buffer + 1` or requests an impossible allocation.
        const MAX_PREALLOC: usize = 1024;
        let max_buffer = max_buffer.max(1);
        Self {
            buffer: BinaryHeap::with_capacity(max_buffer.saturating_add(1).min(MAX_PREALLOC)),
            max_buffer,
        }
    }

    /// Adds a packet to the buffer, keyed by its current `pts` and `dts`.
    pub fn push(&mut self, pkt: CodecPacket) {
        let entry = HeapEntry(pkt.pts, pkt.dts, pkt);
        self.buffer.push(Reverse(entry));
    }

    /// Removes and returns the smallest-key packet once the window is full.
    ///
    /// Returns `None` while fewer than `max_buffer` packets are buffered.
    pub fn pop_ready(&mut self) -> Option<CodecPacket> {
        if self.buffer.len() >= self.max_buffer {
            self.buffer.pop().map(|Reverse(HeapEntry(_, _, pkt))| pkt)
        } else {
            None
        }
    }

    /// Removes and returns every buffered packet in ascending `(pts, dts)` order.
    pub fn drain(&mut self) -> Vec<CodecPacket> {
        let mut out = Vec::with_capacity(self.buffer.len());
        while let Some(Reverse(HeapEntry(_, _, pkt))) = self.buffer.pop() {
            out.push(pkt);
        }
        out
    }

    /// Number of packets currently buffered.
    #[must_use]
    pub fn len(&self) -> usize {
        self.buffer.len()
    }

    /// Whether no packets are buffered.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.buffer.is_empty()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a payload-less packet on stream 0 with default flags.
    fn packet(pts: u64, dts: u64, duration: u32, time_base: (u32, u32)) -> CodecPacket {
        CodecPacket {
            pts,
            dts,
            duration,
            time_base,
            data: Vec::new(),
            flags: PacketFlags::default(),
            stream_index: 0,
        }
    }

    #[test]
    fn builder_first_pts_zero() {
        let mut builder = PacketBuilder::new(0, (1, 90_000), 30.0);
        let first = builder.build_video_frame(vec![0u8; 10], true);
        assert_eq!(first.pts, 0, "first packet PTS must be 0");
        assert_eq!(first.dts, 0, "first packet DTS must be 0");
    }

    #[test]
    fn builder_pts_advances() {
        let mut builder = PacketBuilder::new(0, (1, 90_000), 30.0);
        let step = u64::from(builder.frame_duration());
        let first = builder.build_video_frame(Vec::new(), true);
        let second = builder.build_video_frame(Vec::new(), false);
        assert_eq!(
            second.pts - first.pts,
            step,
            "PTS must advance by frame_duration"
        );
    }

    #[test]
    fn builder_keyframe_flag() {
        let mut builder = PacketBuilder::new(1, (1, 90_000), 25.0);
        let key = builder.build_video_frame(Vec::new(), true);
        let delta = builder.build_video_frame(Vec::new(), false);
        assert_eq!((key.flags.keyframe, delta.flags.keyframe), (true, false));
    }

    #[test]
    fn builder_stream_index() {
        let mut builder = PacketBuilder::new(42, (1, 44_100), 25.0);
        assert_eq!(builder.build_audio_frame(vec![0u8; 4], 1024).stream_index, 42);
    }

    #[test]
    fn builder_audio_no_keyframe() {
        let mut builder = PacketBuilder::new(1, (1, 44_100), 0.0);
        assert!(!builder.build_audio_frame(Vec::new(), 1024).flags.keyframe);
    }

    #[test]
    fn builder_audio_duration_from_samples() {
        let mut builder = PacketBuilder::new(1, (1, 48_000), 25.0);
        let audio = builder.build_audio_frame(Vec::new(), 960);
        assert_eq!(audio.duration, 960, "audio duration must equal sample count");
    }

    #[test]
    fn pts_secs_conversion() {
        let secs = packet(90_000, 90_000, 3000, (1, 90_000)).pts_secs();
        assert!(
            (secs - 1.0).abs() < 1e-9,
            "pts_secs should be 1.0, got {secs}"
        );
    }

    #[test]
    fn dts_secs_conversion() {
        let half_second = packet(45_000, 45_000, 3000, (1, 90_000));
        assert!((half_second.dts_secs() - 0.5).abs() < 1e-9);
    }

    #[test]
    fn rebase_90k_to_1000() {
        let rebased = packet(90_000, 90_000, 3_000, (1, 90_000)).rebase((1, 1_000));
        assert_eq!(
            rebased.pts, 1_000,
            "90000 ticks @ 1/90000 = 1000 ticks @ 1/1000"
        );
        assert_eq!(rebased.duration, 33, "3000/90000 * 1000 ≈ 33 ms");
    }

    #[test]
    fn reorderer_empty_returns_none() {
        assert!(PacketReorderer::new(4).pop_ready().is_none());
    }

    #[test]
    fn reorderer_pts_order() {
        let mut reorderer = PacketReorderer::new(3);
        for &(pts, dts) in &[(0, 0), (3, 1), (1, 2), (2, 3)] {
            reorderer.push(packet(pts, dts, 1, (1, 90_000)));
        }
        // Windowed pops first, then whatever the final drain flushes.
        let mut pts_order: Vec<u64> = std::iter::from_fn(|| reorderer.pop_ready())
            .map(|p| p.pts)
            .collect();
        pts_order.extend(reorderer.drain().into_iter().map(|p| p.pts));
        let mut sorted = pts_order.clone();
        sorted.sort_unstable();
        assert_eq!(
            pts_order, sorted,
            "packets must emerge in PTS ascending order"
        );
    }

    #[test]
    fn reorderer_drain_all() {
        let mut reorderer = PacketReorderer::new(8);
        for i in 0..5_u64 {
            // PTS descends while DTS ascends, so the reorderer must invert arrival order.
            reorderer.push(packet(4 - i, i, 1, (1, 25)));
        }
        let drained = reorderer.drain();
        assert_eq!(drained.len(), 5, "drain must return all 5 packets");
        let pts: Vec<u64> = drained.iter().map(|p| p.pts).collect();
        let mut sorted = pts.clone();
        sorted.sort_unstable();
        assert_eq!(pts, sorted, "drained packets must be in PTS order");
    }
}