Skip to main content

oximedia_codec/
packet_builder.rs

1//! Codec packet and frame building utilities.
2//!
3//! This module provides:
4//!
5//! - [`CodecPacket`] — a rich packet type carrying timestamps, flags, and payload.
6//! - [`PacketFlags`] — per-packet Boolean flags (keyframe, corrupt, discard).
7//! - [`PacketBuilder`] — a stateful helper that produces correctly-timestamped
8//!   [`CodecPacket`]s for video and audio streams.
//! - [`PacketReorderer`] — a PTS-keyed priority queue with a configurable
//!   reorder window that converts DTS-ordered packets (from the encoder)
//!   into PTS order (for the muxer / consumer).
11//!
12//! # Timestamp arithmetic
13//!
14//! All timestamps are expressed in **time-base units**.  A time base of
15//! `(1, 90000)` means each unit represents 1/90000 of a second.  The helpers
16//! [`CodecPacket::pts_secs`] and [`CodecPacket::dts_secs`] convert to seconds.
17//! [`CodecPacket::rebase`] rescales all timestamps to a different time base.
18//!
19//! # Example
20//!
21//! ```rust
22//! use oximedia_codec::packet_builder::PacketBuilder;
23//!
24//! // Build video packets at 30 fps with 90 kHz time base.
25//! let mut builder = PacketBuilder::new(0, (1, 90_000), 30.0);
26//! let pkt = builder.build_video_frame(vec![0xAB; 1024], true);
27//! assert!(pkt.flags.keyframe);
28//! assert_eq!(pkt.stream_index, 0);
29//! ```
30
31#![allow(clippy::cast_lossless)]
32#![allow(clippy::cast_precision_loss)]
33#![allow(clippy::cast_possible_truncation)]
34#![allow(clippy::cast_sign_loss)]
35
36use std::cmp::Reverse;
37use std::collections::BinaryHeap;
38
39// ──────────────────────────────────────────────
40// PacketFlags
41// ──────────────────────────────────────────────
42
/// Per-packet boolean flags.
///
/// All flags default to `false`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct PacketFlags {
    /// The packet contains a keyframe (random access point).
    pub keyframe: bool,
    /// The packet data may be corrupt or partially lost.
    pub corrupt: bool,
    /// The packet should be decoded but not displayed.
    pub discard: bool,
}

// ──────────────────────────────────────────────
// CodecPacket
// ──────────────────────────────────────────────

/// A single codec-level packet (compressed frame data + timestamps).
///
/// Timestamps are stored as unsigned integers in time-base units.  The
/// time base is carried in the packet itself so that consumers do not need
/// out-of-band information to interpret the timestamps.
#[derive(Debug, Clone)]
pub struct CodecPacket {
    /// Presentation timestamp — when this frame should be displayed.
    pub pts: u64,
    /// Decode timestamp — when this frame must be decoded.
    pub dts: u64,
    /// Frame duration in time-base units.
    pub duration: u32,
    /// Time base as `(numerator, denominator)`.
    /// Seconds = `pts * numerator / denominator`.
    pub time_base: (u32, u32),
    /// Compressed frame payload.
    pub data: Vec<u8>,
    /// Boolean flags.
    pub flags: PacketFlags,
    /// Index of the stream this packet belongs to.
    pub stream_index: u32,
}

impl CodecPacket {
    /// PTS in seconds.
    ///
    /// Returns `0.0` when the time-base denominator is zero.
    #[must_use]
    pub fn pts_secs(&self) -> f64 {
        self.ticks_to_secs(self.pts)
    }

    /// DTS in seconds.
    ///
    /// Returns `0.0` when the time-base denominator is zero.
    #[must_use]
    pub fn dts_secs(&self) -> f64 {
        self.ticks_to_secs(self.dts)
    }

    /// Rescale all timestamps (PTS, DTS and duration) to `new_time_base`,
    /// returning a new packet; payload, flags and stream index are copied.
    ///
    /// The conversion is `ticks * old_num * new_den / (old_den * new_num)`,
    /// evaluated in 128-bit integer arithmetic with halves rounded up — for
    /// these non-negative values this matches
    /// `av_rescale_rnd(…, AV_ROUND_NEAR_INF)`.  Results that do not fit the
    /// target field saturate: PTS/DTS at `u64::MAX`, duration at `u32::MAX`.
    ///
    /// If the current time-base denominator or `new_time_base`'s numerator is
    /// zero the conversion factor is undefined; the timestamps are then copied
    /// unchanged, although the returned packet still reports `new_time_base`.
    #[must_use]
    pub fn rebase(&self, new_time_base: (u32, u32)) -> Self {
        let duration = self.rescale_ticks(
            u64::from(self.duration),
            new_time_base,
            u64::from(u32::MAX),
        );

        Self {
            pts: self.rescale_ticks(self.pts, new_time_base, u64::MAX),
            dts: self.rescale_ticks(self.dts, new_time_base, u64::MAX),
            // `rescale_ticks` never exceeds the `u32::MAX` cap passed above,
            // so the fallback is unreachable; it just avoids an `as` cast.
            duration: u32::try_from(duration).unwrap_or(u32::MAX),
            time_base: new_time_base,
            data: self.data.clone(),
            flags: self.flags.clone(),
            stream_index: self.stream_index,
        }
    }

    /// Convert `ticks` expressed in this packet's time base to seconds,
    /// returning `0.0` when the time-base denominator is zero.
    fn ticks_to_secs(&self, ticks: u64) -> f64 {
        let (num, den) = self.time_base;
        if den == 0 {
            return 0.0;
        }
        ticks as f64 * num as f64 / den as f64
    }

    /// Rescale `ticks` from this packet's time base to `new_time_base`,
    /// rounding halves up and saturating at `max`.
    ///
    /// Returns `ticks` unchanged when the conversion factor is undefined
    /// (old denominator or new numerator is zero).
    fn rescale_ticks(&self, ticks: u64, new_time_base: (u32, u32), max: u64) -> u64 {
        let (old_num, old_den) = self.time_base;
        let (new_num, new_den) = new_time_base;

        let denominator = u128::from(old_den) * u128::from(new_num);
        if denominator == 0 {
            return ticks;
        }
        // Worst case (2^64-1)·(2^32-1)² plus denominator/2 (< 2^63) still
        // fits in u128, so neither the product nor the rounding bias overflows.
        let numerator = u128::from(ticks) * u128::from(old_num) * u128::from(new_den);
        let rounded = (numerator + denominator / 2) / denominator;
        // Upscaling to a finer time base can exceed the target width:
        // saturate instead of letting a narrowing cast wrap around.
        rounded.min(u128::from(max)) as u64
    }
}
149
150// ──────────────────────────────────────────────
151// PacketBuilder
152// ──────────────────────────────────────────────
153
154/// A stateful helper for building correctly-timestamped [`CodecPacket`]s.
155///
156/// Create one builder per stream.  Call [`build_video_frame`] for each video
157/// frame or [`build_audio_frame`] for each audio frame; the builder tracks
158/// the running PTS/DTS automatically.
159///
160/// [`build_video_frame`]: PacketBuilder::build_video_frame
161/// [`build_audio_frame`]: PacketBuilder::build_audio_frame
162pub struct PacketBuilder {
163    /// Stream index embedded in every produced packet.
164    stream_index: u32,
165    /// Time base embedded in every produced packet.
166    time_base: (u32, u32),
167    /// Next PTS to assign (incremented by `frame_duration` after each call).
168    pts_counter: u64,
169    /// Next DTS to assign.
170    dts_counter: u64,
171    /// Duration of one video frame in time-base units.
172    frame_duration: u32,
173}
174
175impl PacketBuilder {
176    /// Create a new builder.
177    ///
178    /// - `stream_index`: stream index embedded into every packet.
179    /// - `time_base`: `(numerator, denominator)` of the stream time base.
180    /// - `fps`: frame rate; used to compute `frame_duration`.
181    ///
182    /// `frame_duration` is computed as
183    /// `round(time_base.denominator / (fps * time_base.numerator))`,
184    /// clamped to at least 1.
185    #[must_use]
186    pub fn new(stream_index: u32, time_base: (u32, u32), fps: f32) -> Self {
187        let (num, den) = time_base;
188        let frame_duration = if num == 0 || fps <= 0.0 {
189            1
190        } else {
191            ((den as f64 / (fps as f64 * num as f64)).round() as u32).max(1)
192        };
193
194        Self {
195            stream_index,
196            time_base,
197            pts_counter: 0,
198            dts_counter: 0,
199            frame_duration,
200        }
201    }
202
203    /// Build a video frame packet and advance the timestamp counters.
204    ///
205    /// The PTS and DTS of the returned packet reflect the state **before** the
206    /// counters are advanced, so the first packet always has PTS/DTS = 0.
207    pub fn build_video_frame(&mut self, data: Vec<u8>, keyframe: bool) -> CodecPacket {
208        let pkt = CodecPacket {
209            pts: self.pts_counter,
210            dts: self.dts_counter,
211            duration: self.frame_duration,
212            time_base: self.time_base,
213            data,
214            flags: PacketFlags {
215                keyframe,
216                corrupt: false,
217                discard: false,
218            },
219            stream_index: self.stream_index,
220        };
221
222        self.pts_counter = self.pts_counter.saturating_add(self.frame_duration as u64);
223        self.dts_counter = self.dts_counter.saturating_add(self.frame_duration as u64);
224        pkt
225    }
226
227    /// Build an audio frame packet.
228    ///
229    /// `samples` is the number of PCM samples in the frame.  The duration is
230    /// derived from `samples * time_base.numerator / sample_rate`, but since
231    /// `PacketBuilder` is not audio-sample-rate aware, the caller should pass
232    /// the actual per-frame sample count and the method uses `frame_duration`
233    /// as a fallback when `samples == 0`.
234    ///
235    /// For audio the `keyframe` flag is always `false` (all audio frames are
236    /// independently decodable).
237    pub fn build_audio_frame(&mut self, data: Vec<u8>, samples: u32) -> CodecPacket {
238        let duration = if samples > 0 {
239            samples
240        } else {
241            self.frame_duration
242        };
243
244        let pkt = CodecPacket {
245            pts: self.pts_counter,
246            dts: self.dts_counter,
247            duration,
248            time_base: self.time_base,
249            data,
250            flags: PacketFlags {
251                keyframe: false,
252                corrupt: false,
253                discard: false,
254            },
255            stream_index: self.stream_index,
256        };
257
258        self.pts_counter = self.pts_counter.saturating_add(duration as u64);
259        self.dts_counter = self.dts_counter.saturating_add(duration as u64);
260        pkt
261    }
262
263    /// Return the current PTS counter (PTS that will be assigned to the next packet).
264    #[must_use]
265    pub fn next_pts(&self) -> u64 {
266        self.pts_counter
267    }
268
269    /// Return the current frame duration in time-base units.
270    #[must_use]
271    pub fn frame_duration(&self) -> u32 {
272        self.frame_duration
273    }
274}
275
276// ──────────────────────────────────────────────
277// PacketReorderer
278// ──────────────────────────────────────────────
279
280/// A wrapper that makes [`CodecPacket`] comparable by PTS for the heap.
281///
282/// The heap orders by `Reverse((pts, dts))` so that the packet with the
283/// smallest PTS is always at the top.
284#[derive(Debug)]
285struct HeapEntry(u64, u64, CodecPacket); // (pts, dts, packet)
286
287impl PartialEq for HeapEntry {
288    fn eq(&self, other: &Self) -> bool {
289        self.0 == other.0 && self.1 == other.1
290    }
291}
292
293impl Eq for HeapEntry {}
294
295impl PartialOrd for HeapEntry {
296    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
297        Some(self.cmp(other))
298    }
299}
300
301impl Ord for HeapEntry {
302    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
303        // Order by PTS ascending, break ties by DTS ascending.
304        (self.0, self.1).cmp(&(other.0, other.1))
305    }
306}
307
308/// Reorders DTS-ordered packets (as produced by encoders with B-frames) into
309/// PTS order (required by muxers and decoders operating on display order).
310///
311/// Internally uses a min-heap keyed on PTS.  A packet is considered *ready*
312/// once the heap has accumulated at least `max_buffer` entries, which bounds
313/// the maximum PTS reordering delay.
314///
315/// # Flushing
316///
317/// Call [`drain`] at end-of-stream to retrieve all remaining packets in PTS
318/// order.
319///
320/// [`drain`]: PacketReorderer::drain
321pub struct PacketReorderer {
322    /// Min-heap of `Reverse(HeapEntry)` so the smallest PTS surfaces first.
323    buffer: BinaryHeap<Reverse<HeapEntry>>,
324    /// Maximum packets buffered before [`pop_ready`] will return a packet.
325    ///
326    /// [`pop_ready`]: PacketReorderer::pop_ready
327    max_buffer: usize,
328}
329
330impl PacketReorderer {
331    /// Create a new reorderer.
332    ///
333    /// `max_buffer` controls the maximum reorder window.  A value of 4–8 is
334    /// appropriate for streams with up to 3 consecutive B-frames.
335    #[must_use]
336    pub fn new(max_buffer: usize) -> Self {
337        Self {
338            buffer: BinaryHeap::with_capacity(max_buffer + 1),
339            max_buffer: max_buffer.max(1),
340        }
341    }
342
343    /// Push a packet into the reorder buffer.
344    pub fn push(&mut self, pkt: CodecPacket) {
345        let entry = HeapEntry(pkt.pts, pkt.dts, pkt);
346        self.buffer.push(Reverse(entry));
347    }
348
349    /// Pop the packet with the lowest PTS if the buffer is full enough to
350    /// guarantee it is the next in display order.
351    ///
352    /// Returns `None` if the buffer is smaller than `max_buffer`.
353    pub fn pop_ready(&mut self) -> Option<CodecPacket> {
354        if self.buffer.len() >= self.max_buffer {
355            self.buffer.pop().map(|Reverse(HeapEntry(_, _, pkt))| pkt)
356        } else {
357            None
358        }
359    }
360
361    /// Drain all remaining packets from the buffer, ordered by PTS ascending.
362    ///
363    /// The buffer is empty after this call.
364    pub fn drain(&mut self) -> Vec<CodecPacket> {
365        let mut out = Vec::with_capacity(self.buffer.len());
366        while let Some(Reverse(HeapEntry(_, _, pkt))) = self.buffer.pop() {
367            out.push(pkt);
368        }
369        out
370    }
371
372    /// Number of packets currently buffered.
373    #[must_use]
374    pub fn len(&self) -> usize {
375        self.buffer.len()
376    }
377
378    /// Returns `true` if the buffer contains no packets.
379    #[must_use]
380    pub fn is_empty(&self) -> bool {
381        self.buffer.is_empty()
382    }
383}
384
385// ──────────────────────────────────────────────
386// Tests
387// ──────────────────────────────────────────────
388
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a packet with an empty payload, default flags and stream index 0.
    fn packet(pts: u64, dts: u64, duration: u32, time_base: (u32, u32)) -> CodecPacket {
        CodecPacket {
            pts,
            dts,
            duration,
            time_base,
            data: vec![],
            flags: PacketFlags::default(),
            stream_index: 0,
        }
    }

    // ── 1. PacketBuilder: first packet has PTS/DTS = 0 ──────────────────────

    #[test]
    fn builder_first_pts_zero() {
        let mut b = PacketBuilder::new(0, (1, 90_000), 30.0);
        let p = b.build_video_frame(vec![0u8; 10], true);
        assert_eq!(p.pts, 0, "first packet PTS must be 0");
        assert_eq!(p.dts, 0, "first packet DTS must be 0");
    }

    // ── 2. PacketBuilder: consecutive frames advance PTS/DTS by frame_duration

    #[test]
    fn builder_pts_advances() {
        let mut b = PacketBuilder::new(0, (1, 90_000), 30.0);
        let dur = b.frame_duration();
        assert_eq!(dur, 3_000, "30 fps at 90 kHz is 3000 ticks per frame");
        let p0 = b.build_video_frame(vec![], true);
        let p1 = b.build_video_frame(vec![], false);
        assert_eq!(
            p1.pts - p0.pts,
            dur as u64,
            "PTS must advance by frame_duration"
        );
        assert_eq!(
            p1.dts - p0.dts,
            dur as u64,
            "DTS must advance by frame_duration"
        );
    }

    // ── 3. PacketBuilder: keyframe flag is propagated ────────────────────────

    #[test]
    fn builder_keyframe_flag() {
        let mut b = PacketBuilder::new(1, (1, 90_000), 25.0);
        let key = b.build_video_frame(vec![], true);
        let non_key = b.build_video_frame(vec![], false);
        assert!(key.flags.keyframe);
        assert!(!non_key.flags.keyframe);
    }

    // ── 4. PacketBuilder: stream_index is embedded ───────────────────────────

    #[test]
    fn builder_stream_index() {
        let mut b = PacketBuilder::new(42, (1, 44_100), 25.0);
        let p = b.build_audio_frame(vec![0u8; 4], 1024);
        assert_eq!(p.stream_index, 42);
    }

    // ── 5. PacketBuilder: audio frame keyframe is always false ───────────────

    #[test]
    fn builder_audio_no_keyframe() {
        let mut b = PacketBuilder::new(1, (1, 44_100), 0.0);
        let p = b.build_audio_frame(vec![], 1024);
        assert!(!p.flags.keyframe);
    }

    // ── 6. PacketBuilder: audio duration equals sample count ─────────────────

    #[test]
    fn builder_audio_duration_from_samples() {
        let mut b = PacketBuilder::new(1, (1, 48_000), 25.0);
        let p = b.build_audio_frame(vec![], 960);
        assert_eq!(p.duration, 960, "audio duration must equal sample count");
    }

    // ── 7. CodecPacket::pts_secs: correct conversion ─────────────────────────

    #[test]
    fn pts_secs_conversion() {
        let pkt = packet(90_000, 90_000, 3000, (1, 90_000));
        let secs = pkt.pts_secs();
        assert!(
            (secs - 1.0).abs() < 1e-9,
            "pts_secs should be 1.0, got {secs}"
        );
    }

    // ── 8. CodecPacket::dts_secs: correct conversion ─────────────────────────

    #[test]
    fn dts_secs_conversion() {
        let pkt = packet(45_000, 45_000, 3000, (1, 90_000));
        assert!((pkt.dts_secs() - 0.5).abs() < 1e-9);
    }

    // ── 9. CodecPacket::rebase: 90kHz → 1/1000 ──────────────────────────────

    #[test]
    fn rebase_90k_to_1000() {
        let pkt = packet(90_000, 90_000, 3_000, (1, 90_000));
        let rebased = pkt.rebase((1, 1_000));
        assert_eq!(
            rebased.pts, 1_000,
            "90000 ticks @ 1/90000 = 1000 ticks @ 1/1000"
        );
        assert_eq!(rebased.dts, 1_000, "DTS must be rescaled like PTS");
        assert_eq!(rebased.duration, 33, "3000/90000 * 1000 ≈ 33 ms");
        assert_eq!(rebased.time_base, (1, 1_000));
    }

    // ── 10. PacketReorderer: empty buffer returns None ───────────────────────

    #[test]
    fn reorderer_empty_returns_none() {
        let mut r = PacketReorderer::new(4);
        assert!(r.pop_ready().is_none());
    }

    // ── 11. PacketReorderer: packets released in PTS order ──────────────────

    #[test]
    fn reorderer_pts_order() {
        let mut r = PacketReorderer::new(3);

        // Push 4 packets with scrambled PTS order (simulating B-frames).
        for (pts, dts) in [(0, 0), (3, 1), (1, 2), (2, 3)] {
            r.push(packet(pts, dts, 1, (1, 90_000)));
        }

        // With max_buffer=3 we can pop once buffer >= 3.
        let mut pts_order = Vec::new();
        while let Some(p) = r.pop_ready() {
            pts_order.push(p.pts);
        }
        pts_order.extend(r.drain().into_iter().map(|p| p.pts));

        // Asserting the exact sequence also catches dropped/duplicated packets.
        assert_eq!(
            pts_order,
            vec![0, 1, 2, 3],
            "all packets must emerge in PTS ascending order"
        );
        assert!(r.is_empty());
    }

    // ── 12. PacketReorderer::drain: returns all packets ─────────────────────

    #[test]
    fn reorderer_drain_all() {
        let mut r = PacketReorderer::new(8);
        for i in 0..5_u64 {
            r.push(packet(4 - i, i, 1, (1, 25))); // reverse PTS order
        }
        let drained = r.drain();
        assert_eq!(drained.len(), 5, "drain must return all 5 packets");
        let pts: Vec<u64> = drained.iter().map(|p| p.pts).collect();
        assert_eq!(pts, vec![0, 1, 2, 3, 4], "drained packets must be in PTS order");
        assert!(r.is_empty(), "buffer must be empty after drain");
    }
}