oximedia_transcode/multi_track.rs
//! Multi-track frame-level pipeline executor with DTS-ordered interleaving.
//!
//! This module provides [`MultiTrackExecutor`] — the real decode → filter →
//! encode engine for `oximedia-transcode`. It connects the `FrameDecoder` /
//! `FilterGraph` / `FrameEncoder` plumbing from [`pipeline_context`] to a
//! container-level output via a [`Muxer`] and performs *DTS-ordered
//! interleaving* across all tracks using a min-heap.
//!
//! # Architecture
//!
//! ```text
//! [FrameDecoder₀] → FilterGraph₀ → [FrameEncoder₀] ─┐
//! [FrameDecoder₁] → FilterGraph₁ → [FrameEncoder₁] ─┤→ DTS min-heap → Muxer
//! …                                                 ┘
//! ```
//!
//! ## Execute loop
//!
//! 1. For each active track, call [`FrameDecoder::decode_next`].
//! 2. Apply the track's [`FilterGraph::apply`] to the decoded frame.
//! 3. Pass filtered frames to the track's [`FrameEncoder::encode_frame`].
//! 4. Push resulting encoded bytes as a `StagedPacket` onto the DTS min-heap.
//! 5. After all tracks are exhausted, flush each encoder.
//! 6. Pop the heap in DTS order and write every packet to the [`Muxer`].
//!
//! The [`MultiTrackExecutor::step`] method performs one packet-cycle (one pass
//! through all tracks) and pushes ready encoded data into the internal staging
//! buffer, so a segment or parallel driver can call it externally.
//!
//! [`pipeline_context`]: crate::pipeline_context
31
32#![allow(clippy::module_name_repetitions)]
33
34use std::cmp::Reverse;
35use std::collections::BinaryHeap;
36
37use bytes::Bytes;
38use oximedia_container::{Muxer, Packet, PacketFlags, StreamInfo};
39use oximedia_core::{Rational, Timestamp};
40
41use crate::pipeline_context::{FilterGraph, FrameDecoder, FrameEncoder};
42use crate::{Result, TranscodeError};
43
44// ─── PerTrack ─────────────────────────────────────────────────────────────────
45
46/// One logical media track wired through decode → filter → encode.
47///
48/// Created by the caller with concrete decoder, filter graph, and encoder
49/// implementations, then handed to [`MultiTrackExecutor::add_track`].
50pub struct PerTrack {
51 /// The stream index in the output muxer for packets from this track.
52 pub stream_index: usize,
53 /// Decoder for this track.
54 pub decoder: Box<dyn FrameDecoder>,
55 /// Filter graph applied between decode and encode.
56 pub filter_graph: FilterGraph,
57 /// Encoder for this track.
58 pub encoder: Box<dyn FrameEncoder>,
59 /// `true` when the decoder has reported EOF and the encoder has been flushed.
60 pub flushed: bool,
61 /// Frame counter; used to derive synthetic PTS for flush packets.
62 frame_count: u64,
63 /// Accumulated encoded-bytes count (public for stats queries).
64 pub encoded_bytes: u64,
65 /// Accumulated encoded-frame count (public for stats queries).
66 pub encoded_frames: u64,
67 /// Whether this track carries audio (`true`) or video (`false`).
68 ///
69 /// Determined from the first decoded frame and used to populate the
70 /// `is_audio` flag on flush tail-packets — ensuring the muxer receives
71 /// correct stream-type information at EOS even when no frames flow.
72 is_audio: Option<bool>,
73}
74
75impl PerTrack {
76 /// Create a new [`PerTrack`] with the given stream index, decoder,
77 /// filter graph, and encoder.
78 #[must_use]
79 pub fn new(
80 stream_index: usize,
81 decoder: Box<dyn FrameDecoder>,
82 filter_graph: FilterGraph,
83 encoder: Box<dyn FrameEncoder>,
84 ) -> Self {
85 Self {
86 stream_index,
87 decoder,
88 filter_graph,
89 encoder,
90 flushed: false,
91 frame_count: 0,
92 encoded_bytes: 0,
93 encoded_frames: 0,
94 is_audio: None,
95 }
96 }
97
98 /// Create a new [`PerTrack`] whose track type is known at construction.
99 ///
100 /// Use this constructor when the stream kind (audio vs video) is
101 /// available from the container's [`StreamInfo`] before decoding starts,
102 /// so that `flush_encoder` emits packets with the correct type even if
103 /// no frames were decoded (e.g., a very short audio track).
104 #[must_use]
105 pub fn new_typed(
106 stream_index: usize,
107 decoder: Box<dyn FrameDecoder>,
108 filter_graph: FilterGraph,
109 encoder: Box<dyn FrameEncoder>,
110 is_audio: bool,
111 ) -> Self {
112 Self {
113 stream_index,
114 decoder,
115 filter_graph,
116 encoder,
117 flushed: false,
118 frame_count: 0,
119 encoded_bytes: 0,
120 encoded_frames: 0,
121 is_audio: Some(is_audio),
122 }
123 }
124
125 /// Step this track by one frame: decode → filter → encode.
126 ///
127 /// Returns `Ok(Some(TrackEncoded))` when a frame was successfully encoded,
128 /// `Ok(None)` when the decoder produced no frame (EOF or frame dropped by
129 /// filter), or an error if encoding or filter operations fail.
130 fn step_frame(&mut self) -> Result<Option<TrackEncoded>> {
131 if self.flushed || self.decoder.eof() {
132 return Ok(None);
133 }
134
135 let frame = match self.decoder.decode_next() {
136 Some(f) => f,
137 None => return Ok(None),
138 };
139
140 let pts_ms = frame.pts_ms;
141 let is_audio = frame.is_audio;
142
143 // Latch the track kind from the first frame so flush_encoder can use it.
144 if self.is_audio.is_none() {
145 self.is_audio = Some(is_audio);
146 }
147
148 let filtered = match self.filter_graph.apply(frame)? {
149 Some(f) => f,
150 None => {
151 // Frame dropped by filter — counts as dropped, not an error.
152 return Ok(None);
153 }
154 };
155
156 let encoded = self.encoder.encode_frame(&filtered)?;
157 let n = encoded.len() as u64;
158 self.encoded_bytes += n;
159 self.encoded_frames += 1;
160 self.frame_count += 1;
161
162 Ok(Some(TrackEncoded {
163 data: encoded,
164 pts_ms,
165 is_audio,
166 }))
167 }
168
169 /// Flush the encoder and return remaining encoded bytes (if any).
170 ///
171 /// Sets `self.flushed = true` after the first call; subsequent calls are
172 /// no-ops.
173 ///
174 /// The `is_audio` flag on the returned tail-packet is taken from the track
175 /// type latched during [`step_frame`](Self::step_frame) (or set at
176 /// construction via [`new_typed`](Self::new_typed)). If neither path has
177 /// provided a type yet (a zero-frame track), the flush packet is omitted
178 /// entirely since there is no stream kind to report.
179 fn flush_encoder(&mut self) -> Result<Option<TrackEncoded>> {
180 if self.flushed {
181 return Ok(None);
182 }
183 self.flushed = true;
184 let data = self.encoder.flush()?;
185 if data.is_empty() {
186 return Ok(None);
187 }
188 // If the track type is still unknown (zero-frame track), skip the
189 // flush packet rather than reporting a wrong stream kind to the muxer.
190 let is_audio = match self.is_audio {
191 Some(v) => v,
192 None => return Ok(None),
193 };
194 self.encoded_bytes += data.len() as u64;
195 // Derive a synthetic PTS from the frame count (33 ms/frame ≈ 30 fps).
196 let pts_ms = self.frame_count as i64 * 33;
197 Ok(Some(TrackEncoded {
198 data,
199 pts_ms,
200 is_audio,
201 }))
202 }
203}
204
205// ─── TrackEncoded ─────────────────────────────────────────────────────────────
206
/// Result of pushing one frame through a track, or of flushing its encoder:
/// the encoded bytes plus the metadata needed to stage them for muxing.
#[derive(Debug)]
struct TrackEncoded {
    /// Encoded payload bytes.
    data: Vec<u8>,
    /// Presentation timestamp attached to the payload (`_ms` suffix:
    /// milliseconds).
    pts_ms: i64,
    /// `true` when the payload belongs to an audio track.
    is_audio: bool,
}
214
215// ─── StagedPacket ─────────────────────────────────────────────────────────────
216
/// An encoded packet parked in the DTS min-heap until it is written to the
/// muxer.
///
/// Equality and ordering consider only `(dts_ms, stream_index)`; the payload
/// and the audio flag never take part in comparisons.
#[derive(Debug)]
struct StagedPacket {
    /// Effective DTS — the primary heap ordering key.
    dts_ms: i64,
    /// Stream index for the muxer; also the tie-breaker between equal DTS.
    stream_index: usize,
    /// Encoded payload.
    data: Vec<u8>,
    /// `true` for audio packets.
    is_audio: bool,
}

impl StagedPacket {
    /// The pair every comparison is based on: DTS first, then stream index.
    fn sort_key(&self) -> (i64, usize) {
        (self.dts_ms, self.stream_index)
    }
}

impl PartialEq for StagedPacket {
    fn eq(&self, other: &Self) -> bool {
        self.sort_key() == other.sort_key()
    }
}

impl Eq for StagedPacket {}

impl PartialOrd for StagedPacket {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for StagedPacket {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // Tuples compare lexicographically: DTS ascending, then stream index
        // ascending so equal-DTS packets come out in a deterministic order.
        self.sort_key().cmp(&other.sort_key())
    }
}
252
253// ─── MultiTrackStats ──────────────────────────────────────────────────────────
254
/// Counters reported by [`MultiTrackExecutor::execute`] and readable at any
/// time through [`MultiTrackExecutor::stats`].
///
/// All counters start at zero (`Default`).
#[derive(Clone, Debug, Default)]
pub struct MultiTrackStats {
    /// Total encoded frames across all tracks.
    pub total_encoded_frames: u64,
    /// Total encoded bytes across all tracks.
    pub total_encoded_bytes: u64,
    /// Number of packets written to the muxer in DTS order.
    pub packets_muxed: u64,
    /// Number of frames dropped by filter graphs.
    ///
    /// NOTE(review): no code visible in this module increments this counter,
    /// so it appears to stay at zero — confirm whether it is maintained
    /// elsewhere.
    pub frames_dropped: u64,
}
267
268// ─── MultiTrackExecutor ───────────────────────────────────────────────────────
269
270/// Frame-level multi-track decode → filter → encode executor with
271/// DTS-ordered muxing.
272///
273/// # Usage
274///
275/// ```rust,ignore
276/// use oximedia_transcode::multi_track::{MultiTrackExecutor, PerTrack};
277/// use oximedia_transcode::pipeline_context::{FilterGraph, Frame};
278///
279/// // Supply concrete FrameDecoder / FrameEncoder implementations:
280/// let mut executor = MultiTrackExecutor::new(muxer);
281/// executor.add_track(PerTrack::new(0, decoder0, FilterGraph::new(), encoder0));
282/// executor.add_track(PerTrack::new(1, decoder1, FilterGraph::new(), encoder1));
283/// let stats = executor.execute(&streams).await?;
284/// ```
285pub struct MultiTrackExecutor<M: Muxer> {
286 /// Per-track decode/filter/encode pipelines.
287 tracks: Vec<PerTrack>,
288 /// The output container muxer.
289 muxer: M,
290 /// DTS min-heap: `Reverse` turns `BinaryHeap` (max-heap) into a min-heap.
291 heap: BinaryHeap<Reverse<StagedPacket>>,
292 /// Timebase used for `Timestamp` construction (1 ms resolution by default).
293 timebase: Rational,
294 /// Drain the heap to the muxer every `flush_interval` step cycles.
295 flush_interval: u64,
296 /// Step counter for flush scheduling.
297 step_count: u64,
298 /// `true` after all tracks have reached EOF.
299 tracks_done: bool,
300 /// Accumulated statistics.
301 stats: MultiTrackStats,
302}
303
304impl<M: Muxer> MultiTrackExecutor<M> {
305 /// Default flush interval (drain heap every 30 steps).
306 const DEFAULT_FLUSH_INTERVAL: u64 = 30;
307
308 /// Creates a new executor wrapping `muxer`.
309 ///
310 /// Tracks must be added with [`add_track`](Self::add_track) before calling
311 /// [`execute`](Self::execute) or [`step`](Self::step).
312 pub fn new(muxer: M) -> Self {
313 Self {
314 tracks: Vec::new(),
315 muxer,
316 heap: BinaryHeap::new(),
317 timebase: Rational::new(1, 1_000),
318 flush_interval: Self::DEFAULT_FLUSH_INTERVAL,
319 step_count: 0,
320 tracks_done: false,
321 stats: MultiTrackStats::default(),
322 }
323 }
324
325 /// Adds a [`PerTrack`] to the executor.
326 pub fn add_track(&mut self, track: PerTrack) {
327 self.tracks.push(track);
328 }
329
330 /// Overrides the heap flush interval (default: 30 steps).
331 pub fn set_flush_interval(&mut self, n: u64) {
332 self.flush_interval = n.max(1);
333 }
334
335 /// Returns a shared reference to the inner muxer.
336 #[must_use]
337 pub fn muxer(&self) -> &M {
338 &self.muxer
339 }
340
341 /// Consumes the executor and returns the inner muxer.
342 #[must_use]
343 pub fn into_muxer(self) -> M {
344 self.muxer
345 }
346
347 /// Returns the accumulated execution statistics.
348 #[must_use]
349 pub fn stats(&self) -> &MultiTrackStats {
350 &self.stats
351 }
352
353 // ── Internal helpers ──────────────────────────────────────────────────────
354
355 /// Push a [`TrackEncoded`] result onto the DTS heap.
356 fn push_to_heap(&mut self, stream_index: usize, encoded: TrackEncoded) {
357 let packet = StagedPacket {
358 dts_ms: encoded.pts_ms,
359 stream_index,
360 data: encoded.data,
361 is_audio: encoded.is_audio,
362 };
363 self.heap.push(Reverse(packet));
364 }
365
366 /// Drain all packets in the heap to the muxer in DTS order.
367 async fn drain_heap_to_muxer(&mut self) -> Result<()> {
368 while let Some(Reverse(staged)) = self.heap.pop() {
369 self.write_staged_packet(staged).await?;
370 }
371 Ok(())
372 }
373
374 /// Drain heap packets whose DTS is strictly less than `horizon_ms`.
375 ///
376 /// This "safe drain" strategy ensures packets behind the current minimum
377 /// active DTS are flushed promptly, while packets that might still be
378 /// overtaken by a slower track are retained.
379 async fn drain_heap_until(&mut self, horizon_ms: i64) -> Result<()> {
380 loop {
381 match self.heap.peek() {
382 Some(Reverse(staged)) if staged.dts_ms < horizon_ms => {
383 let Reverse(pkt) = self.heap.pop().expect("non-empty after peek");
384 self.write_staged_packet(pkt).await?;
385 }
386 _ => break,
387 }
388 }
389 Ok(())
390 }
391
392 /// Write a single [`StagedPacket`] to the muxer.
393 async fn write_staged_packet(&mut self, staged: StagedPacket) -> Result<()> {
394 let ts = Timestamp::new(staged.dts_ms, self.timebase);
395 let flags = if staged.is_audio {
396 PacketFlags::empty()
397 } else {
398 PacketFlags::KEYFRAME
399 };
400 let pkt = Packet::new(staged.stream_index, Bytes::from(staged.data), ts, flags);
401 self.muxer.write_packet(&pkt).await.map_err(|e| {
402 TranscodeError::ContainerError(format!("muxer write_packet failed: {e}"))
403 })?;
404 self.stats.packets_muxed += 1;
405 Ok(())
406 }
407
408 // ── Public API ────────────────────────────────────────────────────────────
409
410 /// Perform one step of the pipeline: attempt to decode one frame from
411 /// every active track, filter and encode it, then push the result onto
412 /// the DTS heap.
413 ///
414 /// Periodically drains the heap to the muxer based on the minimum active
415 /// DTS (safe-drain strategy).
416 ///
417 /// Returns `true` if at least one track produced an encoded packet this
418 /// step, `false` when all tracks are exhausted.
419 ///
420 /// # Errors
421 ///
422 /// Propagates errors from the filter graph, encoder, or muxer.
423 pub async fn step(&mut self) -> Result<bool> {
424 if self.tracks_done {
425 return Ok(false);
426 }
427
428 // Collect encoded output from all tracks before mutating `self.heap`.
429 // This avoids a double-borrow of `self` when `push_to_heap` is called
430 // inside the loop that also borrows `self.tracks`.
431 let mut pending: Vec<(usize, TrackEncoded)> = Vec::new();
432 let mut min_active_dts: Option<i64> = None;
433
434 for track in &mut self.tracks {
435 if track.flushed || track.decoder.eof() {
436 continue;
437 }
438
439 if let Some(encoded) = track.step_frame()? {
440 let dts = encoded.pts_ms;
441 min_active_dts = Some(match min_active_dts {
442 Some(prev) => prev.min(dts),
443 None => dts,
444 });
445 pending.push((track.stream_index, encoded));
446 }
447 }
448
449 let any_produced = !pending.is_empty();
450 let encoded_this_step = pending.len() as u64;
451
452 // Push collected results onto the DTS heap.
453 for (stream_index, encoded) in pending {
454 self.push_to_heap(stream_index, encoded);
455 }
456
457 // Aggregate byte stats.
458 self.stats.total_encoded_bytes = self.tracks.iter().map(|t| t.encoded_bytes).sum();
459 self.stats.total_encoded_frames += encoded_this_step;
460
461 self.step_count += 1;
462
463 // Safe-drain the heap on schedule.
464 if self.step_count % self.flush_interval == 0 {
465 if let Some(horizon) = min_active_dts {
466 self.drain_heap_until(horizon).await?;
467 }
468 }
469
470 // Update done flag.
471 let all_done = self.tracks.iter().all(|t| t.decoder.eof() || t.flushed);
472 if all_done {
473 self.tracks_done = true;
474 }
475
476 Ok(any_produced)
477 }
478
479 /// Execute the full pipeline end-to-end.
480 ///
481 /// 1. Registers `streams` with the muxer and writes the header.
482 /// 2. Calls [`step`](Self::step) in a loop until all tracks are exhausted.
483 /// 3. Flushes each track's encoder.
484 /// 4. Drains the remaining heap to the muxer in DTS order.
485 /// 5. Writes the muxer trailer.
486 ///
487 /// Returns accumulated [`MultiTrackStats`].
488 ///
489 /// # Errors
490 ///
491 /// Returns an error if any stage (filter, encode, mux header/packet/trailer)
492 /// fails.
493 pub async fn execute(&mut self, streams: &[StreamInfo]) -> Result<MultiTrackStats> {
494 // Register streams with the muxer.
495 for stream in streams {
496 self.muxer
497 .add_stream(stream.clone())
498 .map_err(|e| TranscodeError::ContainerError(format!("add_stream failed: {e}")))?;
499 }
500
501 self.muxer
502 .write_header()
503 .await
504 .map_err(|e| TranscodeError::ContainerError(format!("write_header failed: {e}")))?;
505
506 // Main decode/encode loop.
507 loop {
508 let produced = self.step().await?;
509 if self.tracks_done {
510 break;
511 }
512 if !produced {
513 // No track produced a packet — check whether they are all at EOF.
514 let all_eof = self.tracks.iter().all(|t| t.decoder.eof() || t.flushed);
515 if all_eof {
516 self.tracks_done = true;
517 break;
518 }
519 }
520 }
521
522 // Flush each encoder.
523 for idx in 0..self.tracks.len() {
524 let stream_index = self.tracks[idx].stream_index;
525 if let Some(encoded) = self.tracks[idx].flush_encoder()? {
526 self.push_to_heap(stream_index, encoded);
527 }
528 }
529
530 // Final full heap drain in DTS order.
531 self.drain_heap_to_muxer().await?;
532
533 // Finalise stats.
534 self.stats.total_encoded_bytes = self.tracks.iter().map(|t| t.encoded_bytes).sum();
535 self.stats.total_encoded_frames = self.tracks.iter().map(|t| t.encoded_frames).sum();
536
537 self.muxer
538 .write_trailer()
539 .await
540 .map_err(|e| TranscodeError::ContainerError(format!("write_trailer failed: {e}")))?;
541
542 Ok(self.stats.clone())
543 }
544}