use std::cmp::Reverse;
use std::collections::BinaryHeap;
use crate::error::{CodecError, CodecResult};
/// A frame waiting to be ordered, identified by its presentation timestamp.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct QueuedFrame {
    /// Presentation timestamp.
    pub pts: i64,
    /// Coding type, which drives the reordering decisions.
    pub frame_type: QueueFrameType,
    /// Opaque payload bytes; never inspected by the queueing/reordering code.
    pub data: Vec<u8>,
}

/// Coding type of a frame.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QueueFrameType {
    /// Intra-coded frame (I-frame); treated as an anchor by the reorder buffer.
    Intra,
    /// Predicted frame (P-frame); treated as an anchor by the reorder buffer.
    Inter,
    /// Bidirectionally predicted frame (B-frame); may be held back by the
    /// reorder buffer before being emitted.
    BiPredicted,
}

/// A frame released by the reorder buffer: the queued frame plus its assigned DTS.
///
/// Derives `PartialEq`/`Eq` like [`QueuedFrame`] so output frames can be compared.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ReadyFrame {
    /// Presentation timestamp, copied from the input frame.
    pub pts: i64,
    /// Decode timestamp assigned by the reorder buffer.
    pub dts: i64,
    /// Coding type, copied from the input frame.
    pub frame_type: QueueFrameType,
    /// Payload bytes, moved from the input frame.
    pub data: Vec<u8>,
}
/// Priority queue that hands frames back in ascending PTS order and refuses
/// to hold two frames with the same PTS.
#[derive(Default, Debug)]
pub struct FrameQueue {
    /// `BinaryHeap` is a max-heap; wrapping entries in `Reverse` makes the
    /// smallest PTS come out first.
    heap: BinaryHeap<Reverse<PtsOrdFrame>>,
    /// Every PTS currently stored in `heap`, kept for duplicate detection.
    pts_set: std::collections::BTreeSet<i64>,
}
#[derive(Debug, Clone, PartialEq, Eq)]
struct PtsOrdFrame(QueuedFrame);
impl PartialOrd for PtsOrdFrame {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}
impl Ord for PtsOrdFrame {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
self.0.pts.cmp(&other.0.pts)
}
}
impl FrameQueue {
    /// Creates an empty queue.
    pub fn new() -> Self {
        Self::default()
    }

    /// Inserts `frame`, keyed by its PTS.
    ///
    /// # Errors
    /// Returns `CodecError::InvalidParameter` if a frame with the same PTS is
    /// already queued; the queue is left unchanged in that case.
    pub fn push(&mut self, frame: QueuedFrame) -> CodecResult<()> {
        // `insert` reports whether the value was new, so the duplicate check
        // and the bookkeeping share a single set lookup.
        if !self.pts_set.insert(frame.pts) {
            return Err(CodecError::InvalidParameter(format!(
                "duplicate PTS {} in frame queue",
                frame.pts
            )));
        }
        self.heap.push(Reverse(PtsOrdFrame(frame)));
        Ok(())
    }

    /// Removes and returns the frame with the smallest PTS, making that PTS
    /// available for reuse.
    pub fn pop(&mut self) -> Option<QueuedFrame> {
        let Reverse(PtsOrdFrame(frame)) = self.heap.pop()?;
        self.pts_set.remove(&frame.pts);
        Some(frame)
    }

    /// Returns the smallest queued PTS without removing the frame.
    pub fn peek_pts(&self) -> Option<i64> {
        self.heap.peek().map(|Reverse(PtsOrdFrame(f))| f.pts)
    }

    /// Number of queued frames.
    pub fn len(&self) -> usize {
        self.heap.len()
    }

    /// `true` when no frames are queued.
    pub fn is_empty(&self) -> bool {
        self.heap.is_empty()
    }

    /// Removes every frame and returns them in ascending PTS order; the queue
    /// is empty afterwards.
    pub fn drain_ordered(&mut self) -> Vec<QueuedFrame> {
        let mut result = Vec::with_capacity(self.heap.len());
        result.extend(std::iter::from_fn(|| self.pop()));
        result
    }
}
/// Parameters for `BFrameReorderBuffer`.
#[derive(Clone, Debug)]
pub struct ReorderConfig {
    /// Upper bound on how many B-frames the reorder buffer holds back at once.
    /// It is also the number of `min_dts_delta` steps by which the first DTS
    /// precedes the first PTS. `0` makes B-frames pass straight through.
    pub max_b_frames: usize,
    /// Timebase numerator (not read by `BFrameReorderBuffer`).
    pub timebase_num: u32,
    /// Timebase denominator (not read by `BFrameReorderBuffer`).
    pub timebase_den: u32,
    /// Amount the DTS advances for every emitted frame, presumably in timebase ticks.
    pub min_dts_delta: i64,
}

impl Default for ReorderConfig {
    /// Two B-frames, a 1/90 000 timebase and a 3000-tick DTS step
    /// (3000 / 90 000 s, i.e. one frame at 30 fps).
    fn default() -> Self {
        ReorderConfig {
            min_dts_delta: 3000,
            timebase_den: 90_000,
            timebase_num: 1,
            max_b_frames: 2,
        }
    }
}
/// Reorders a display-order (PTS-order) frame stream into decode order and
/// stamps each frame with a monotonically increasing DTS.
///
/// B-frames are held back until the next anchor (I- or P-frame) arrives. The
/// anchor is then emitted first, followed by the held B-frames sorted by PTS,
/// because those B-frames use the anchor as a forward reference and must be
/// decoded after it. For example, `I0 B1 B2 P3` comes out as `I0 P3 B1 B2`.
#[derive(Debug)]
pub struct BFrameReorderBuffer {
    /// Reorder depth and DTS step.
    config: ReorderConfig,
    /// B-frames waiting for the anchor that follows them in display order.
    pending: Vec<QueuedFrame>,
    /// `Some(first_dts)` once the first frame has been pushed; only used as an
    /// "initialised" marker.
    next_dts: Option<i64>,
    /// DTS that the next emitted frame receives.
    dts_counter: i64,
    /// Frames in decode order with DTS assigned, ready to be popped.
    output: std::collections::VecDeque<ReadyFrame>,
}

impl BFrameReorderBuffer {
    /// Creates an empty buffer using `config`.
    pub fn new(config: ReorderConfig) -> Self {
        Self {
            config,
            pending: Vec::new(),
            next_dts: None,
            dts_counter: 0,
            output: std::collections::VecDeque::new(),
        }
    }

    /// Creates an empty buffer using `ReorderConfig::default()`.
    pub fn default_config() -> Self {
        Self::new(ReorderConfig::default())
    }

    /// Feeds the next frame, in display (PTS) order.
    ///
    /// The first frame pushed fixes the DTS origin at
    /// `pts - max_b_frames * min_dts_delta`. This head start absorbs the
    /// reorder delay and is intended to keep DTS <= PTS when frames are spaced
    /// `min_dts_delta` apart.
    ///
    /// A run of up to `max_b_frames` B-frames is held until its anchor
    /// arrives. A B-frame beyond that run length forces the held ones out
    /// early instead of letting the buffer grow. With `max_b_frames == 0`,
    /// B-frames are passed straight through.
    pub fn push(&mut self, frame: QueuedFrame) {
        if self.next_dts.is_none() {
            let offset = (self.config.max_b_frames as i64) * self.config.min_dts_delta;
            let initial_dts = frame.pts - offset;
            self.next_dts = Some(initial_dts);
            self.dts_counter = initial_dts;
        }
        match frame.frame_type {
            QueueFrameType::Intra | QueueFrameType::Inter => {
                // Anchor first: the held B-frames reference it, so they must
                // come after it in decode order.
                self.emit_frame(frame);
                self.flush_pending_b_frames();
            }
            QueueFrameType::BiPredicted if self.config.max_b_frames == 0 => {
                self.emit_frame(frame);
            }
            QueueFrameType::BiPredicted => {
                // A full run is legal and must still wait for its anchor; only
                // a B-frame beyond the configured run length forces a flush.
                if self.pending.len() >= self.config.max_b_frames {
                    self.flush_pending_b_frames();
                }
                self.pending.push(frame);
            }
        }
    }

    /// Emits any held B-frames without waiting for an anchor (e.g. at end of stream).
    pub fn flush(&mut self) {
        self.flush_pending_b_frames();
    }

    /// Removes and returns the next frame in decode order, if one is ready.
    pub fn pop(&mut self) -> Option<ReadyFrame> {
        self.output.pop_front()
    }

    /// Number of frames ready to be popped.
    pub fn ready_len(&self) -> usize {
        self.output.len()
    }

    /// Number of B-frames currently held back.
    pub fn pending_len(&self) -> usize {
        self.pending.len()
    }

    /// Emits all held B-frames in ascending PTS order (stable for equal PTS).
    fn flush_pending_b_frames(&mut self) {
        // Move the vector out so `emit_frame` can borrow `self` mutably, then
        // put it back so its allocation is reused for the next run.
        let mut held = std::mem::take(&mut self.pending);
        held.sort_by_key(|f| f.pts);
        for f in held.drain(..) {
            self.emit_frame(f);
        }
        self.pending = held;
    }

    /// Appends `frame` to the output with the current DTS, then advances the
    /// DTS counter by `min_dts_delta`.
    fn emit_frame(&mut self, frame: QueuedFrame) {
        let dts = self.dts_counter;
        self.dts_counter += self.config.min_dts_delta;
        self.output.push_back(ReadyFrame {
            pts: frame.pts,
            dts,
            frame_type: frame.frame_type,
            data: frame.data,
        });
    }
}
/// Stateful generator of monotonically increasing decode timestamps.
///
/// The first call to `next` (after construction or `reset`) starts the
/// sequence at `pts - max_b_frames * min_delta`. Every later call returns the
/// previous value plus `min_delta`, regardless of the PTS passed in.
#[derive(Debug)]
pub struct DtsCalculator {
    /// Step between consecutive DTS values; checked to be positive in `new`.
    min_delta: i64,
    /// Number of `min_delta` steps by which the first DTS precedes the first PTS.
    max_b_frames: usize,
    /// DTS to hand out next; `None` until the first call or after `reset`.
    next_dts: Option<i64>,
}

impl DtsCalculator {
    /// Builds a calculator with the given step and reorder depth.
    ///
    /// # Errors
    /// Returns `CodecError::InvalidParameter` when `min_delta` is zero or negative.
    pub fn new(min_delta: i64, max_b_frames: usize) -> CodecResult<Self> {
        if min_delta < 1 {
            return Err(CodecError::InvalidParameter(
                "DtsCalculator: min_delta must be positive".into(),
            ));
        }
        Ok(Self {
            next_dts: None,
            min_delta,
            max_b_frames,
        })
    }

    /// Returns the DTS for the next frame in decode order.
    ///
    /// `pts` only matters on the first call after construction or `reset`;
    /// the keyframe flag is currently ignored.
    pub fn next(&mut self, pts: i64, _is_keyframe: bool) -> i64 {
        let step = self.min_delta;
        let depth = self.max_b_frames as i64;
        // Seed lazily so the start offset is only computed when no sequence exists yet.
        let slot = self.next_dts.get_or_insert_with(|| pts - depth * step);
        let current = *slot;
        *slot += step;
        current
    }

    /// Calls `next` for each `(pts, is_keyframe)` pair and collects the results in order.
    pub fn compute_batch(&mut self, frames: &[(i64, bool)]) -> Vec<i64> {
        let mut out = Vec::with_capacity(frames.len());
        for &(pts, is_key) in frames {
            out.push(self.next(pts, is_key));
        }
        out
    }

    /// Forgets the current sequence; the next call to `next` reseeds from its `pts`.
    pub fn reset(&mut self) {
        self.next_dts = None;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a frame whose single payload byte is the (truncated) PTS.
    fn make_frame(pts: i64, ft: QueueFrameType) -> QueuedFrame {
        let data = vec![pts as u8];
        QueuedFrame {
            pts,
            frame_type: ft,
            data,
        }
    }

    /// Pops every ready frame from `buf`, in output order.
    fn drain_ready(buf: &mut BFrameReorderBuffer) -> Vec<ReadyFrame> {
        std::iter::from_fn(|| buf.pop()).collect()
    }

    #[test]
    fn test_frame_queue_push_pop_ordered() {
        let mut q = FrameQueue::new();
        let input = [
            (200, QueueFrameType::Inter),
            (0, QueueFrameType::Intra),
            (100, QueueFrameType::BiPredicted),
        ];
        for (pts, ft) in input {
            q.push(make_frame(pts, ft)).unwrap();
        }
        for expected in [0i64, 100, 200] {
            assert_eq!(q.pop().unwrap().pts, expected);
        }
        assert!(q.pop().is_none());
    }

    #[test]
    fn test_frame_queue_duplicate_pts_error() {
        let mut q = FrameQueue::new();
        q.push(make_frame(100, QueueFrameType::Intra)).unwrap();
        assert!(q.push(make_frame(100, QueueFrameType::Inter)).is_err());
    }

    #[test]
    fn test_frame_queue_drain_ordered() {
        let mut q = FrameQueue::new();
        [500i64, 100, 300, 0, 200]
            .iter()
            .for_each(|&pts| q.push(make_frame(pts, QueueFrameType::Inter)).unwrap());
        let pts_seq: Vec<i64> = q.drain_ordered().into_iter().map(|f| f.pts).collect();
        assert_eq!(pts_seq, vec![0, 100, 200, 300, 500]);
    }

    #[test]
    fn test_frame_queue_peek_pts() {
        let mut q = FrameQueue::new();
        assert!(q.peek_pts().is_none());
        q.push(make_frame(50, QueueFrameType::Intra)).unwrap();
        q.push(make_frame(10, QueueFrameType::Inter)).unwrap();
        assert_eq!(q.peek_pts(), Some(10));
    }

    #[test]
    fn test_b_frame_reorder_anchor_before_b() {
        let mut buf = BFrameReorderBuffer::new(ReorderConfig {
            max_b_frames: 2,
            min_dts_delta: 1,
            ..Default::default()
        });
        let input = [
            (0, QueueFrameType::Intra),
            (1, QueueFrameType::BiPredicted),
            (2, QueueFrameType::BiPredicted),
            (3, QueueFrameType::Inter),
        ];
        for (pts, ft) in input {
            buf.push(make_frame(pts, ft));
        }
        buf.flush();
        let out = drain_ready(&mut buf);
        assert!(!out.is_empty());
        assert!(
            out.windows(2).all(|w| w[1].dts >= w[0].dts),
            "DTS must be non-decreasing"
        );
    }

    #[test]
    fn test_b_frame_reorder_dts_leq_pts() {
        let mut buf = BFrameReorderBuffer::new(ReorderConfig {
            max_b_frames: 2,
            min_dts_delta: 3000,
            ..Default::default()
        });
        let pts_sequence = [0i64, 3000, 6000, 9000, 12000];
        for (i, pts) in pts_sequence.iter().copied().enumerate() {
            let ft = match i % 3 {
                0 => QueueFrameType::Intra,
                1 => QueueFrameType::BiPredicted,
                _ => QueueFrameType::Inter,
            };
            buf.push(make_frame(pts, ft));
        }
        buf.flush();
        for f in drain_ready(&mut buf) {
            assert!(f.dts <= f.pts, "DTS ({}) must be <= PTS ({})", f.dts, f.pts);
        }
    }

    #[test]
    fn test_dts_calculator_basic() {
        let mut calc = DtsCalculator::new(3000, 2).unwrap();
        let frames: Vec<(i64, bool)> = [6000i64, 9000, 12000, 15000]
            .iter()
            .enumerate()
            .map(|(i, &p)| (p, i == 0))
            .collect();
        let dts = calc.compute_batch(&frames);
        assert_eq!(dts[0], 0);
        assert!(dts.windows(2).all(|w| w[1] - w[0] == 3000));
    }

    #[test]
    fn test_dts_calculator_invalid_delta() {
        for bad in [0i64, -1] {
            assert!(DtsCalculator::new(bad, 2).is_err());
        }
    }

    #[test]
    fn test_dts_calculator_reset() {
        let mut calc = DtsCalculator::new(1000, 1).unwrap();
        let before = calc.next(5000, true);
        calc.reset();
        let after = calc.next(5000, true);
        assert_eq!(before, after);
    }

    #[test]
    fn test_no_b_frames_passthrough() {
        let mut buf = BFrameReorderBuffer::new(ReorderConfig {
            max_b_frames: 0,
            min_dts_delta: 1,
            ..Default::default()
        });
        (0i64..4).for_each(|pts| buf.push(make_frame(pts, QueueFrameType::BiPredicted)));
        buf.flush();
        let pts_out: Vec<i64> = drain_ready(&mut buf).iter().map(|f| f.pts).collect();
        assert_eq!(pts_out.len(), 4);
    }
}