Skip to main content

oximedia_transcode/
multi_track.rs

1//! Multi-track frame-level pipeline executor with DTS-ordered interleaving.
2//!
3//! This module provides [`MultiTrackExecutor`] — the real decode → filter →
4//! encode engine for `oximedia-transcode`.  It connects the `FrameDecoder` /
5//! `FilterGraph` / `FrameEncoder` plumbing from [`pipeline_context`] to a
6//! container-level output via a [`Muxer`] and performs *DTS-ordered
7//! interleaving* across all tracks using a min-heap.
8//!
9//! # Architecture
10//!
11//! ```text
12//!    [FrameDecoder₀] → FilterGraph₀ → [FrameEncoder₀] ─┐
13//!    [FrameDecoder₁] → FilterGraph₁ → [FrameEncoder₁] ─┤→ DTS min-heap → Muxer
14//!           …                                            ┘
15//! ```
16//!
17//! ## Execute loop
18//!
19//! 1. For each active track, call [`FrameDecoder::decode_next`].
20//! 2. Apply the track's [`FilterGraph::apply`] to the decoded frame.
21//! 3. Pass filtered frames to the track's [`FrameEncoder::encode_frame`].
22//! 4. Push resulting encoded bytes as a `StagedPacket` onto the DTS min-heap.
23//! 5. After all tracks are exhausted, flush each encoder.
24//! 6. Pop the heap in DTS order and write every packet to the [`Muxer`].
25//!
26//! The [`MultiTrackExecutor::step`] method performs one packet-cycle (one pass
27//! through all tracks) and pushes ready encoded data into the internal staging
28//! buffer, so a segment or parallel driver can call it externally.
29//!
30//! [`pipeline_context`]: crate::pipeline_context
31
32#![allow(clippy::module_name_repetitions)]
33
34use std::cmp::Reverse;
35use std::collections::BinaryHeap;
36
37use bytes::Bytes;
38use oximedia_container::{Muxer, Packet, PacketFlags, StreamInfo};
39use oximedia_core::{Rational, Timestamp};
40
41use crate::pipeline_context::{FilterGraph, FrameDecoder, FrameEncoder};
42use crate::{Result, TranscodeError};
43
44// ─── PerTrack ─────────────────────────────────────────────────────────────────
45
46/// One logical media track wired through decode → filter → encode.
47///
48/// Created by the caller with concrete decoder, filter graph, and encoder
49/// implementations, then handed to [`MultiTrackExecutor::add_track`].
50pub struct PerTrack {
51    /// The stream index in the output muxer for packets from this track.
52    pub stream_index: usize,
53    /// Decoder for this track.
54    pub decoder: Box<dyn FrameDecoder>,
55    /// Filter graph applied between decode and encode.
56    pub filter_graph: FilterGraph,
57    /// Encoder for this track.
58    pub encoder: Box<dyn FrameEncoder>,
59    /// `true` when the decoder has reported EOF and the encoder has been flushed.
60    pub flushed: bool,
61    /// Frame counter; used to derive synthetic PTS for flush packets.
62    frame_count: u64,
63    /// Accumulated encoded-bytes count (public for stats queries).
64    pub encoded_bytes: u64,
65    /// Accumulated encoded-frame count (public for stats queries).
66    pub encoded_frames: u64,
67    /// Whether this track carries audio (`true`) or video (`false`).
68    ///
69    /// Determined from the first decoded frame and used to populate the
70    /// `is_audio` flag on flush tail-packets — ensuring the muxer receives
71    /// correct stream-type information at EOS even when no frames flow.
72    is_audio: Option<bool>,
73}
74
75impl PerTrack {
76    /// Create a new [`PerTrack`] with the given stream index, decoder,
77    /// filter graph, and encoder.
78    #[must_use]
79    pub fn new(
80        stream_index: usize,
81        decoder: Box<dyn FrameDecoder>,
82        filter_graph: FilterGraph,
83        encoder: Box<dyn FrameEncoder>,
84    ) -> Self {
85        Self {
86            stream_index,
87            decoder,
88            filter_graph,
89            encoder,
90            flushed: false,
91            frame_count: 0,
92            encoded_bytes: 0,
93            encoded_frames: 0,
94            is_audio: None,
95        }
96    }
97
98    /// Create a new [`PerTrack`] whose track type is known at construction.
99    ///
100    /// Use this constructor when the stream kind (audio vs video) is
101    /// available from the container's [`StreamInfo`] before decoding starts,
102    /// so that `flush_encoder` emits packets with the correct type even if
103    /// no frames were decoded (e.g., a very short audio track).
104    #[must_use]
105    pub fn new_typed(
106        stream_index: usize,
107        decoder: Box<dyn FrameDecoder>,
108        filter_graph: FilterGraph,
109        encoder: Box<dyn FrameEncoder>,
110        is_audio: bool,
111    ) -> Self {
112        Self {
113            stream_index,
114            decoder,
115            filter_graph,
116            encoder,
117            flushed: false,
118            frame_count: 0,
119            encoded_bytes: 0,
120            encoded_frames: 0,
121            is_audio: Some(is_audio),
122        }
123    }
124
125    /// Step this track by one frame: decode → filter → encode.
126    ///
127    /// Returns `Ok(Some(TrackEncoded))` when a frame was successfully encoded,
128    /// `Ok(None)` when the decoder produced no frame (EOF or frame dropped by
129    /// filter), or an error if encoding or filter operations fail.
130    fn step_frame(&mut self) -> Result<Option<TrackEncoded>> {
131        if self.flushed || self.decoder.eof() {
132            return Ok(None);
133        }
134
135        let frame = match self.decoder.decode_next() {
136            Some(f) => f,
137            None => return Ok(None),
138        };
139
140        let pts_ms = frame.pts_ms;
141        let is_audio = frame.is_audio;
142
143        // Latch the track kind from the first frame so flush_encoder can use it.
144        if self.is_audio.is_none() {
145            self.is_audio = Some(is_audio);
146        }
147
148        let filtered = match self.filter_graph.apply(frame)? {
149            Some(f) => f,
150            None => {
151                // Frame dropped by filter — counts as dropped, not an error.
152                return Ok(None);
153            }
154        };
155
156        let encoded = self.encoder.encode_frame(&filtered)?;
157        let n = encoded.len() as u64;
158        self.encoded_bytes += n;
159        self.encoded_frames += 1;
160        self.frame_count += 1;
161
162        Ok(Some(TrackEncoded {
163            data: encoded,
164            pts_ms,
165            is_audio,
166        }))
167    }
168
169    /// Flush the encoder and return remaining encoded bytes (if any).
170    ///
171    /// Sets `self.flushed = true` after the first call; subsequent calls are
172    /// no-ops.
173    ///
174    /// The `is_audio` flag on the returned tail-packet is taken from the track
175    /// type latched during [`step_frame`](Self::step_frame) (or set at
176    /// construction via [`new_typed`](Self::new_typed)).  If neither path has
177    /// provided a type yet (a zero-frame track), the flush packet is omitted
178    /// entirely since there is no stream kind to report.
179    fn flush_encoder(&mut self) -> Result<Option<TrackEncoded>> {
180        if self.flushed {
181            return Ok(None);
182        }
183        self.flushed = true;
184        let data = self.encoder.flush()?;
185        if data.is_empty() {
186            return Ok(None);
187        }
188        // If the track type is still unknown (zero-frame track), skip the
189        // flush packet rather than reporting a wrong stream kind to the muxer.
190        let is_audio = match self.is_audio {
191            Some(v) => v,
192            None => return Ok(None),
193        };
194        self.encoded_bytes += data.len() as u64;
195        // Derive a synthetic PTS from the frame count (33 ms/frame ≈ 30 fps).
196        let pts_ms = self.frame_count as i64 * 33;
197        Ok(Some(TrackEncoded {
198            data,
199            pts_ms,
200            is_audio,
201        }))
202    }
203}
204
205// ─── TrackEncoded ─────────────────────────────────────────────────────────────
206
207/// Encoded output produced by a single [`PerTrack::step_frame`] call.
208#[derive(Debug)]
209struct TrackEncoded {
210    data: Vec<u8>,
211    pts_ms: i64,
212    is_audio: bool,
213}
214
215// ─── StagedPacket ─────────────────────────────────────────────────────────────
216
217/// An encoded packet waiting in the DTS min-heap for muxer output.
218#[derive(Debug)]
219struct StagedPacket {
220    /// Effective DTS for heap ordering.
221    dts_ms: i64,
222    /// Stream index for the muxer.
223    stream_index: usize,
224    /// Encoded payload.
225    data: Vec<u8>,
226    /// `true` for audio packets.
227    is_audio: bool,
228}
229
230impl PartialEq for StagedPacket {
231    fn eq(&self, other: &Self) -> bool {
232        self.dts_ms == other.dts_ms && self.stream_index == other.stream_index
233    }
234}
235
236impl Eq for StagedPacket {}
237
238impl PartialOrd for StagedPacket {
239    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
240        Some(self.cmp(other))
241    }
242}
243
244impl Ord for StagedPacket {
245    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
246        // Primary: DTS ascending; secondary: stream_index ascending for determinism.
247        self.dts_ms
248            .cmp(&other.dts_ms)
249            .then(self.stream_index.cmp(&other.stream_index))
250    }
251}
252
253// ─── MultiTrackStats ──────────────────────────────────────────────────────────
254
255/// Statistics returned by [`MultiTrackExecutor::execute`].
256#[derive(Debug, Clone, Default)]
257pub struct MultiTrackStats {
258    /// Total encoded frames across all tracks.
259    pub total_encoded_frames: u64,
260    /// Total encoded bytes across all tracks.
261    pub total_encoded_bytes: u64,
262    /// Number of packets written to the muxer in DTS order.
263    pub packets_muxed: u64,
264    /// Number of frames dropped by filter graphs.
265    pub frames_dropped: u64,
266}
267
268// ─── MultiTrackExecutor ───────────────────────────────────────────────────────
269
270/// Frame-level multi-track decode → filter → encode executor with
271/// DTS-ordered muxing.
272///
273/// # Usage
274///
275/// ```rust,ignore
276/// use oximedia_transcode::multi_track::{MultiTrackExecutor, PerTrack};
277/// use oximedia_transcode::pipeline_context::{FilterGraph, Frame};
278///
279/// // Supply concrete FrameDecoder / FrameEncoder implementations:
280/// let mut executor = MultiTrackExecutor::new(muxer);
281/// executor.add_track(PerTrack::new(0, decoder0, FilterGraph::new(), encoder0));
282/// executor.add_track(PerTrack::new(1, decoder1, FilterGraph::new(), encoder1));
283/// let stats = executor.execute(&streams).await?;
284/// ```
285pub struct MultiTrackExecutor<M: Muxer> {
286    /// Per-track decode/filter/encode pipelines.
287    tracks: Vec<PerTrack>,
288    /// The output container muxer.
289    muxer: M,
290    /// DTS min-heap: `Reverse` turns `BinaryHeap` (max-heap) into a min-heap.
291    heap: BinaryHeap<Reverse<StagedPacket>>,
292    /// Timebase used for `Timestamp` construction (1 ms resolution by default).
293    timebase: Rational,
294    /// Drain the heap to the muxer every `flush_interval` step cycles.
295    flush_interval: u64,
296    /// Step counter for flush scheduling.
297    step_count: u64,
298    /// `true` after all tracks have reached EOF.
299    tracks_done: bool,
300    /// Accumulated statistics.
301    stats: MultiTrackStats,
302}
303
304impl<M: Muxer> MultiTrackExecutor<M> {
305    /// Default flush interval (drain heap every 30 steps).
306    const DEFAULT_FLUSH_INTERVAL: u64 = 30;
307
308    /// Creates a new executor wrapping `muxer`.
309    ///
310    /// Tracks must be added with [`add_track`](Self::add_track) before calling
311    /// [`execute`](Self::execute) or [`step`](Self::step).
312    pub fn new(muxer: M) -> Self {
313        Self {
314            tracks: Vec::new(),
315            muxer,
316            heap: BinaryHeap::new(),
317            timebase: Rational::new(1, 1_000),
318            flush_interval: Self::DEFAULT_FLUSH_INTERVAL,
319            step_count: 0,
320            tracks_done: false,
321            stats: MultiTrackStats::default(),
322        }
323    }
324
325    /// Adds a [`PerTrack`] to the executor.
326    pub fn add_track(&mut self, track: PerTrack) {
327        self.tracks.push(track);
328    }
329
330    /// Overrides the heap flush interval (default: 30 steps).
331    pub fn set_flush_interval(&mut self, n: u64) {
332        self.flush_interval = n.max(1);
333    }
334
335    /// Returns a shared reference to the inner muxer.
336    #[must_use]
337    pub fn muxer(&self) -> &M {
338        &self.muxer
339    }
340
341    /// Consumes the executor and returns the inner muxer.
342    #[must_use]
343    pub fn into_muxer(self) -> M {
344        self.muxer
345    }
346
347    /// Returns the accumulated execution statistics.
348    #[must_use]
349    pub fn stats(&self) -> &MultiTrackStats {
350        &self.stats
351    }
352
353    // ── Internal helpers ──────────────────────────────────────────────────────
354
355    /// Push a [`TrackEncoded`] result onto the DTS heap.
356    fn push_to_heap(&mut self, stream_index: usize, encoded: TrackEncoded) {
357        let packet = StagedPacket {
358            dts_ms: encoded.pts_ms,
359            stream_index,
360            data: encoded.data,
361            is_audio: encoded.is_audio,
362        };
363        self.heap.push(Reverse(packet));
364    }
365
366    /// Drain all packets in the heap to the muxer in DTS order.
367    async fn drain_heap_to_muxer(&mut self) -> Result<()> {
368        while let Some(Reverse(staged)) = self.heap.pop() {
369            self.write_staged_packet(staged).await?;
370        }
371        Ok(())
372    }
373
374    /// Drain heap packets whose DTS is strictly less than `horizon_ms`.
375    ///
376    /// This "safe drain" strategy ensures packets behind the current minimum
377    /// active DTS are flushed promptly, while packets that might still be
378    /// overtaken by a slower track are retained.
379    async fn drain_heap_until(&mut self, horizon_ms: i64) -> Result<()> {
380        loop {
381            match self.heap.peek() {
382                Some(Reverse(staged)) if staged.dts_ms < horizon_ms => {
383                    let Reverse(pkt) = self.heap.pop().expect("non-empty after peek");
384                    self.write_staged_packet(pkt).await?;
385                }
386                _ => break,
387            }
388        }
389        Ok(())
390    }
391
392    /// Write a single [`StagedPacket`] to the muxer.
393    async fn write_staged_packet(&mut self, staged: StagedPacket) -> Result<()> {
394        let ts = Timestamp::new(staged.dts_ms, self.timebase);
395        let flags = if staged.is_audio {
396            PacketFlags::empty()
397        } else {
398            PacketFlags::KEYFRAME
399        };
400        let pkt = Packet::new(staged.stream_index, Bytes::from(staged.data), ts, flags);
401        self.muxer.write_packet(&pkt).await.map_err(|e| {
402            TranscodeError::ContainerError(format!("muxer write_packet failed: {e}"))
403        })?;
404        self.stats.packets_muxed += 1;
405        Ok(())
406    }
407
408    // ── Public API ────────────────────────────────────────────────────────────
409
410    /// Perform one step of the pipeline: attempt to decode one frame from
411    /// every active track, filter and encode it, then push the result onto
412    /// the DTS heap.
413    ///
414    /// Periodically drains the heap to the muxer based on the minimum active
415    /// DTS (safe-drain strategy).
416    ///
417    /// Returns `true` if at least one track produced an encoded packet this
418    /// step, `false` when all tracks are exhausted.
419    ///
420    /// # Errors
421    ///
422    /// Propagates errors from the filter graph, encoder, or muxer.
423    pub async fn step(&mut self) -> Result<bool> {
424        if self.tracks_done {
425            return Ok(false);
426        }
427
428        // Collect encoded output from all tracks before mutating `self.heap`.
429        // This avoids a double-borrow of `self` when `push_to_heap` is called
430        // inside the loop that also borrows `self.tracks`.
431        let mut pending: Vec<(usize, TrackEncoded)> = Vec::new();
432        let mut min_active_dts: Option<i64> = None;
433
434        for track in &mut self.tracks {
435            if track.flushed || track.decoder.eof() {
436                continue;
437            }
438
439            if let Some(encoded) = track.step_frame()? {
440                let dts = encoded.pts_ms;
441                min_active_dts = Some(match min_active_dts {
442                    Some(prev) => prev.min(dts),
443                    None => dts,
444                });
445                pending.push((track.stream_index, encoded));
446            }
447        }
448
449        let any_produced = !pending.is_empty();
450        let encoded_this_step = pending.len() as u64;
451
452        // Push collected results onto the DTS heap.
453        for (stream_index, encoded) in pending {
454            self.push_to_heap(stream_index, encoded);
455        }
456
457        // Aggregate byte stats.
458        self.stats.total_encoded_bytes = self.tracks.iter().map(|t| t.encoded_bytes).sum();
459        self.stats.total_encoded_frames += encoded_this_step;
460
461        self.step_count += 1;
462
463        // Safe-drain the heap on schedule.
464        if self.step_count % self.flush_interval == 0 {
465            if let Some(horizon) = min_active_dts {
466                self.drain_heap_until(horizon).await?;
467            }
468        }
469
470        // Update done flag.
471        let all_done = self.tracks.iter().all(|t| t.decoder.eof() || t.flushed);
472        if all_done {
473            self.tracks_done = true;
474        }
475
476        Ok(any_produced)
477    }
478
479    /// Execute the full pipeline end-to-end.
480    ///
481    /// 1. Registers `streams` with the muxer and writes the header.
482    /// 2. Calls [`step`](Self::step) in a loop until all tracks are exhausted.
483    /// 3. Flushes each track's encoder.
484    /// 4. Drains the remaining heap to the muxer in DTS order.
485    /// 5. Writes the muxer trailer.
486    ///
487    /// Returns accumulated [`MultiTrackStats`].
488    ///
489    /// # Errors
490    ///
491    /// Returns an error if any stage (filter, encode, mux header/packet/trailer)
492    /// fails.
493    pub async fn execute(&mut self, streams: &[StreamInfo]) -> Result<MultiTrackStats> {
494        // Register streams with the muxer.
495        for stream in streams {
496            self.muxer
497                .add_stream(stream.clone())
498                .map_err(|e| TranscodeError::ContainerError(format!("add_stream failed: {e}")))?;
499        }
500
501        self.muxer
502            .write_header()
503            .await
504            .map_err(|e| TranscodeError::ContainerError(format!("write_header failed: {e}")))?;
505
506        // Main decode/encode loop.
507        loop {
508            let produced = self.step().await?;
509            if self.tracks_done {
510                break;
511            }
512            if !produced {
513                // No track produced a packet — check whether they are all at EOF.
514                let all_eof = self.tracks.iter().all(|t| t.decoder.eof() || t.flushed);
515                if all_eof {
516                    self.tracks_done = true;
517                    break;
518                }
519            }
520        }
521
522        // Flush each encoder.
523        for idx in 0..self.tracks.len() {
524            let stream_index = self.tracks[idx].stream_index;
525            if let Some(encoded) = self.tracks[idx].flush_encoder()? {
526                self.push_to_heap(stream_index, encoded);
527            }
528        }
529
530        // Final full heap drain in DTS order.
531        self.drain_heap_to_muxer().await?;
532
533        // Finalise stats.
534        self.stats.total_encoded_bytes = self.tracks.iter().map(|t| t.encoded_bytes).sum();
535        self.stats.total_encoded_frames = self.tracks.iter().map(|t| t.encoded_frames).sum();
536
537        self.muxer
538            .write_trailer()
539            .await
540            .map_err(|e| TranscodeError::ContainerError(format!("write_trailer failed: {e}")))?;
541
542        Ok(self.stats.clone())
543    }
544}