Skip to main content

container/
nal_mux.rs

//! Mux-side H.264 / H.265 NAL handling: take the encoder's **Annex-B** output
//! (start-code-delimited NAL units), strip the out-of-band parameter sets
//! (SPS/PPS, plus HEVC VPS) for the `avcC`/`hvcC` config box, and repackage the
//! remaining NALs (slices, SEI) as **length-prefixed** (4-byte) samples for the
//! MP4 `mdat`. This is the inverse of the demux path in
//! [`annexb`](crate::annexb), which reads length-prefixed → Annex-B.
//!
//! `avc1`/`hvc1` carry the parameter sets in the sample-entry config box, not
//! in-band, so the per-sample data must NOT repeat them.
10
/// The NAL-based video codec of the bitstream being muxed. It selects how the
/// NAL header is decoded and which `nal_unit_type` values count as parameter
/// sets, delimiters and slices.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NalMuxCodec {
    /// H.264 / AVC: the NAL type is the low five bits of the first header byte.
    H264,
    /// H.265 / HEVC: the NAL type is bits 1..=6 of the first header byte.
    H265,
}
17
/// How the muxer treats a NAL unit: as one of the parameter-set kinds that the
/// sample-entry config box collects, or as ordinary sample payload.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum NalClass {
    /// Video parameter set (only produced for H.265).
    Vps,
    /// Sequence parameter set.
    Sps,
    /// Picture parameter set.
    Pps,
    /// Everything else — slices, SEI, AUD, … — which stays in the sample data.
    Sample,
}
27
28/// `nal_unit_type` for the given codec (0 for an empty NAL).
29fn nal_type(nal: &[u8], codec: NalMuxCodec) -> u8 {
30    if nal.is_empty() {
31        return 0;
32    }
33    match codec {
34        NalMuxCodec::H264 => nal[0] & 0x1F,           // H.264 §7.3.1
35        NalMuxCodec::H265 => (nal[0] >> 1) & 0x3F,    // H.265 §7.3.1.2 (2-byte header)
36    }
37}
38
39/// Classify a NAL unit (payload only, no start code) for the given codec.
40fn classify(nal: &[u8], codec: NalMuxCodec) -> NalClass {
41    match (codec, nal_type(nal, codec)) {
42        (NalMuxCodec::H264, 7) => NalClass::Sps,
43        (NalMuxCodec::H264, 8) => NalClass::Pps,
44        (NalMuxCodec::H265, 32) => NalClass::Vps,
45        (NalMuxCodec::H265, 33) => NalClass::Sps,
46        (NalMuxCodec::H265, 34) => NalClass::Pps,
47        _ => NalClass::Sample,
48    }
49}
50
51/// Access-unit delimiter (H.264 type 9 / H.265 type 35) — starts a new frame.
52fn is_aud(nal: &[u8], codec: NalMuxCodec) -> bool {
53    match codec {
54        NalMuxCodec::H264 => nal_type(nal, codec) == 9,
55        NalMuxCodec::H265 => nal_type(nal, codec) == 35,
56    }
57}
58
59/// Whether this NAL is an IDR / IRAP slice (a keyframe's VCL NAL).
60fn is_idr(nal: &[u8], codec: NalMuxCodec) -> bool {
61    match codec {
62        NalMuxCodec::H264 => nal_type(nal, codec) == 5,              // IDR slice
63        NalMuxCodec::H265 => matches!(nal_type(nal, codec), 16..=23), // BLA..IRAP
64    }
65}
66
67/// Whether this NAL is a VCL (slice) NAL.
68fn is_vcl(nal: &[u8], codec: NalMuxCodec) -> bool {
69    let t = nal_type(nal, codec);
70    match codec {
71        NalMuxCodec::H264 => (1..=5).contains(&t),
72        NalMuxCodec::H265 => t <= 31,
73    }
74}
75
76/// Whether a VCL slice begins a new picture — the access-unit boundary signal
77/// when the encoder emits no AUD. H.264: `first_mb_in_slice == 0` ⟺ the slice
78/// header's leading `ue(v)` is the single bit `1` (top bit set). H.265:
79/// `first_slice_segment_in_pic_flag` is the first bit after the 2-byte header.
80fn first_slice_in_pic(nal: &[u8], codec: NalMuxCodec) -> bool {
81    match codec {
82        NalMuxCodec::H264 => nal.len() > 1 && (nal[1] & 0x80) != 0,
83        NalMuxCodec::H265 => nal.len() > 2 && (nal[2] & 0x80) != 0,
84    }
85}
86
/// A single muxed access unit (one frame), ready to be written to the `mdat`.
#[derive(Debug, Clone)]
pub struct AuSample {
    /// The access unit's NALs, each preceded by its 4-byte big-endian length.
    pub data: Vec<u8>,
    /// Whether the access unit contains an IDR / IRAP slice.
    pub is_keyframe: bool,
}
94
95/// Split an Annex-B buffer into its NAL units (payloads, start codes removed).
96/// Handles both 3-byte (`00 00 01`) and 4-byte (`00 00 00 01`) start codes.
97pub fn split_annexb_nals(data: &[u8]) -> Vec<&[u8]> {
98    let mut nals = Vec::new();
99    let n = data.len();
100    // Position just past the first start code.
101    let mut cursor = match find_start_code(data, 0) {
102        Some((pos, len)) => pos + len,
103        None => return nals, // no start code → not Annex-B / empty
104    };
105    loop {
106        // `find_start_code` reports a 4-byte start code at its first `00`, so the
107        // NAL ends exactly at the next start code — legitimate trailing zero
108        // bytes in the slice RBSP (cabac_zero_words, rbsp trailing) are kept.
109        let (next_pos, next_len) = match find_start_code(data, cursor) {
110            Some(x) => x,
111            None => {
112                if n > cursor {
113                    nals.push(&data[cursor..n]); // last NAL runs to the end
114                }
115                break;
116            }
117        };
118        if next_pos > cursor {
119            nals.push(&data[cursor..next_pos]);
120        }
121        cursor = next_pos + next_len;
122    }
123    nals
124}
125
/// Locate the first start-code **prefix** `00 00 01` whose first byte is at or
/// after offset `from`. Returns `(offset_of_prefix, 3)`, or `None` when there
/// is no further prefix (including when `from` lies beyond the buffer).
///
/// Only the 3-byte prefix is matched on purpose. A 4-byte start code
/// `00 00 00 01` is seen as `[zero_byte] [00 00 01]`, so its leading `00` is
/// left with the *previous* NAL as a harmless trailing zero (decoders ignore
/// it). Greedily consuming it would risk eating a slice's own trailing `0x00`
/// byte and corrupting the slice.
fn find_start_code(data: &[u8], from: usize) -> Option<(usize, usize)> {
    data.get(from..)?
        .windows(3)
        .position(|window| matches!(window, [0, 0, 1]))
        .map(|offset| (from + offset, 3))
}
143
144/// Repackages Annex-B encoder frames into length-prefixed mdat samples while
145/// collecting the parameter sets for the `avcC`/`hvcC` config box.
146///
147/// Two modes:
148/// - **out-of-band** (default): SPS/PPS/VPS are stripped from samples and stored
149///   in the config box. Correct for a single encoder (`avc1`/`hvc1`).
150/// - **inline** ([`new_inline`]): SPS/PPS/VPS are ALSO kept inline in each
151///   access unit (each IDR self-describes). Used by the multi-GPU stitch, where
152///   chunks come from independent encoders (possibly different vendors): the
153///   inline parameter sets let each chunk decode with its own SPS/PPS even if
154///   they differ cosmetically. Pairs with the `avc3`/`hev1` sample entry. The
155///   config box still gets the FIRST set as a default hint.
156#[derive(Debug)]
157pub struct NalSampleWriter {
158    codec: NalMuxCodec,
159    /// HEVC VPS NAL units (empty for H.264), first-seen order, de-duplicated.
160    pub vps: Vec<Vec<u8>>,
161    pub sps: Vec<Vec<u8>>,
162    pub pps: Vec<Vec<u8>>,
163    inline_param_sets: bool,
164}
165
166impl NalSampleWriter {
167    pub fn new(codec: NalMuxCodec) -> Self {
168        Self { codec, vps: Vec::new(), sps: Vec::new(), pps: Vec::new(), inline_param_sets: false }
169    }
170
171    /// Inline-parameter-set mode (for the multi-GPU stitch). Keeps SPS/PPS/VPS
172    /// inline in each access unit AND records the first set for the config box.
173    pub fn new_inline(codec: NalMuxCodec) -> Self {
174        Self { codec, vps: Vec::new(), sps: Vec::new(), pps: Vec::new(), inline_param_sets: true }
175    }
176
177    /// Convert one encoder packet — which may carry **multiple access units**
178    /// (HW encoders return several frames per buffer) — into one
179    /// **length-prefixed** mdat sample *per access unit*. Access units are
180    /// delimited by the AUD NAL (a packet with no AUD is treated as one unit).
181    /// SPS/PPS/VPS are captured (for the config box) and stripped from samples.
182    pub fn push_packet(&mut self, annexb: &[u8]) -> Vec<AuSample> {
183        // Group NALs into access units. A new unit begins at an AUD, or — when
184        // the encoder emits no AUD (QSV H.265) — at the first VCL slice of a new
185        // picture once the current unit already holds a slice.
186        let mut units: Vec<Vec<&[u8]>> = vec![Vec::new()];
187        let mut cur_has_vcl = false;
188        for nal in split_annexb_nals(annexb) {
189            let new_au = is_aud(nal, self.codec)
190                || (is_vcl(nal, self.codec) && cur_has_vcl && first_slice_in_pic(nal, self.codec));
191            if new_au && !units.last().unwrap().is_empty() {
192                units.push(Vec::new());
193                cur_has_vcl = false;
194            }
195            if is_vcl(nal, self.codec) {
196                cur_has_vcl = true;
197            }
198            units.last_mut().unwrap().push(nal);
199        }
200
201        let codec = self.codec;
202        let inline = self.inline_param_sets;
203        let mut samples = Vec::new();
204        for unit in units {
205            let mut data = Vec::new();
206            let mut is_keyframe = false;
207            for nal in unit {
208                let push_inline = |data: &mut Vec<u8>| {
209                    data.extend_from_slice(&(nal.len() as u32).to_be_bytes());
210                    data.extend_from_slice(nal);
211                };
212                match classify(nal, codec) {
213                    NalClass::Sample => {
214                        if is_idr(nal, codec) {
215                            is_keyframe = true;
216                        }
217                        push_inline(&mut data);
218                        continue;
219                    }
220                    NalClass::Vps | NalClass::Sps | NalClass::Pps => {}
221                }
222                // A parameter set (SPS/PPS/VPS):
223                let store = match classify(nal, codec) {
224                    NalClass::Vps => &mut self.vps,
225                    NalClass::Sps => &mut self.sps,
226                    NalClass::Pps => &mut self.pps,
227                    NalClass::Sample => unreachable!(),
228                };
229                if inline {
230                    // Record the first of each kind for the config-box default,
231                    // and keep every parameter set inline in the access unit.
232                    if store.is_empty() {
233                        store.push(nal.to_vec());
234                    }
235                    push_inline(&mut data);
236                } else {
237                    dedup_push(store, nal);
238                }
239            }
240            if !data.is_empty() {
241                samples.push(AuSample { data, is_keyframe });
242            }
243        }
244        samples
245    }
246
247    /// Whether the parameter sets needed for the config box have been seen.
248    pub fn has_param_sets(&self) -> bool {
249        let vps_ok = matches!(self.codec, NalMuxCodec::H264) || !self.vps.is_empty();
250        vps_ok && !self.sps.is_empty() && !self.pps.is_empty()
251    }
252}
253
/// Append a copy of `nal` to `set` unless a byte-identical entry is already
/// present, so the set keeps first-seen order without duplicates.
fn dedup_push(set: &mut Vec<Vec<u8>>, nal: &[u8]) {
    let already_known = set.iter().map(Vec::as_slice).any(|existing| existing == nal);
    if already_known {
        return;
    }
    set.push(nal.to_vec());
}
259
#[cfg(test)]
mod tests {
    use super::*;

    /// One NAL payload behind a 4-byte Annex-B start code.
    fn sc4(nal: &[u8]) -> Vec<u8> {
        [&[0u8, 0, 0, 1][..], nal].concat()
    }

    /// A packet made of the given NALs, each behind a 4-byte start code.
    fn packet(nals: &[&[u8]]) -> Vec<u8> {
        nals.iter().flat_map(|nal| sc4(nal)).collect()
    }

    /// The mdat form of one NAL: 4-byte big-endian length, then the payload.
    fn length_prefixed(nal: &[u8]) -> Vec<u8> {
        let mut v = (nal.len() as u32).to_be_bytes().to_vec();
        v.extend_from_slice(nal);
        v
    }

    #[test]
    fn splits_3_and_4_byte_start_codes() {
        // 4-byte SC, then 3-byte SC
        let buf = [0u8, 0, 0, 1, 0xAA, 0xBB, 0, 0, 1, 0xCC];
        let nals = split_annexb_nals(&buf);
        assert_eq!(nals.len(), 2);
        assert_eq!(nals[0], &[0xAA, 0xBB]);
        assert_eq!(nals[1], &[0xCC]);
    }

    #[test]
    fn h264_strips_sps_pps_keeps_slice() {
        // SPS (type 7), PPS (type 8), IDR slice (type 5)
        let sps: &[u8] = &[0x67, 0x42, 0x00, 0x1e, 0xAA];
        let pps: &[u8] = &[0x68, 0xCE, 0x3C];
        let idr: &[u8] = &[0x65, 0x88, 0x11, 0x22];
        let frame = packet(&[sps, pps, idr]);

        let mut w = NalSampleWriter::new(NalMuxCodec::H264);
        let samples = w.push_packet(&frame);
        assert_eq!(samples.len(), 1, "no AUD → one access unit");
        assert!(samples[0].is_keyframe, "contains an IDR slice");
        // Captured param sets: a following 4-byte start code may leave a
        // harmless trailing 0x00, so only require the original as a prefix.
        assert_eq!(w.sps.len(), 1);
        assert!(w.sps[0].starts_with(sps));
        assert!(w.pps[0].starts_with(pps));
        assert!(w.has_param_sets());
        // The IDR is the last NAL (no start code after it), so it is exact.
        assert_eq!(samples[0].data, length_prefixed(idr));
    }

    #[test]
    fn splits_multi_au_packet_by_aud() {
        // A packet with two AUDs (type 9) → two access-unit samples.
        let aud: &[u8] = &[0x09, 0x10];
        let idr: &[u8] = &[0x65, 0x11];
        let p: &[u8] = &[0x41, 0x22];
        // AU 1: AUD + IDR, AU 2: AUD + P-slice
        let frame = packet(&[aud, idr, aud, p]);

        let mut w = NalSampleWriter::new(NalMuxCodec::H264);
        let samples = w.push_packet(&frame);
        assert_eq!(samples.len(), 2, "two AUDs → two samples");
        assert!(samples[0].is_keyframe, "AU1 has the IDR");
        assert!(!samples[1].is_keyframe, "AU2 is a P-frame");
    }

    #[test]
    fn inline_mode_keeps_param_sets_in_sample() {
        // Multi-GPU stitch: each access unit must self-describe with its own
        // SPS/PPS (avc3/hev1), so a chunk decodes with its own parameter sets.
        let sps: &[u8] = &[0x67, 0x42, 0x00, 0x1e, 0xAA];
        let pps: &[u8] = &[0x68, 0xCE, 0x3C];
        let idr: &[u8] = &[0x65, 0x88, 0x11, 0x22];
        let frame = packet(&[sps, pps, idr]);

        let mut w = NalSampleWriter::new_inline(NalMuxCodec::H264);
        let inline = w.push_packet(&frame);
        assert_eq!(inline.len(), 1);
        assert!(inline[0].is_keyframe);
        // Config box still records the first SPS/PPS as a default hint.
        assert_eq!(w.sps.len(), 1);
        assert!(w.sps[0].starts_with(sps));
        assert_eq!(w.pps.len(), 1);

        // Out-of-band mode strips the params, so its sample is smaller.
        let mut w2 = NalSampleWriter::new(NalMuxCodec::H264);
        let oob = w2.push_packet(&frame);
        assert!(
            inline[0].data.len() > oob[0].data.len(),
            "inline sample (SPS+PPS+IDR) must be larger than the stripped one ({} vs {})",
            inline[0].data.len(),
            oob[0].data.len()
        );
        // The inline sample begins with the length-prefixed SPS bytes.
        assert_eq!(&inline[0].data[4..4 + sps.len()], sps);
    }

    #[test]
    fn h265_splits_multi_picture_packet_without_aud() {
        // QSV H.265 emits no AUD: split on VCL slices with first_slice flag set.
        let idr: &[u8] = &[0x26, 0x01, 0xA0]; // type 19 (IDR), first_slice_segment=1
        let trail: &[u8] = &[0x02, 0x01, 0xA0]; // type 1 (TRAIL_R), first_slice_segment=1
        let frame = packet(&[idr, trail]);

        let mut w = NalSampleWriter::new(NalMuxCodec::H265);
        let samples = w.push_packet(&frame);
        assert_eq!(samples.len(), 2, "two first-slice VCL NALs → two access units");
        assert!(samples[0].is_keyframe);
        assert!(!samples[1].is_keyframe);
    }

    #[test]
    fn h265_captures_vps_sps_pps() {
        let vps: &[u8] = &[0x40, 0x01, 0x0c]; // type 32
        let sps: &[u8] = &[0x42, 0x01, 0x01]; // type 33
        let pps: &[u8] = &[0x44, 0x01, 0xc1]; // type 34
        let slice: &[u8] = &[0x26, 0x01, 0xaf]; // type 19 (IDR_W_RADL)
        let frame = packet(&[vps, sps, pps, slice]);

        let mut w = NalSampleWriter::new(NalMuxCodec::H265);
        let samples = w.push_packet(&frame);
        assert_eq!(samples.len(), 1);
        assert!(samples[0].is_keyframe, "type 19 is an IRAP/IDR");
        assert!(w.vps[0].starts_with(vps));
        assert!(w.sps[0].starts_with(sps));
        assert!(w.pps[0].starts_with(pps));
        assert!(w.has_param_sets());
        assert_eq!(samples[0].data, length_prefixed(slice));
    }

    #[test]
    fn preserves_slice_trailing_zero_bytes() {
        // A slice NAL whose RBSP legitimately ends in zero bytes (cabac_zero_words)
        // must NOT be truncated — that corrupts the slice and breaks decode.
        let slice: &[u8] = &[0x65, 0x88, 0x00, 0x00, 0x00];
        let next: &[u8] = &[0x41, 0x9a]; // a following P-slice

        let frame = packet(&[slice, next]);
        let nals = split_annexb_nals(&frame);
        assert_eq!(nals.len(), 2);
        // The slice's own bytes (incl. its trailing zeros) are never eaten; a
        // 4-byte next start code may leave one harmless extra trailing 0x00.
        assert!(nals[0].starts_with(slice), "slice trailing zeros must survive: {:?}", nals[0]);
        assert!(nals[1].starts_with(next));

        // 3-byte next start code: the slice is preserved exactly.
        let f2 = [sc4(slice).as_slice(), &[0, 0, 1], next].concat();
        let n2 = split_annexb_nals(&f2);
        assert_eq!(n2[0], slice, "trailing zeros kept exactly with a 3-byte next start code");
    }

    #[test]
    fn dedups_repeated_param_sets() {
        let sps: &[u8] = &[0x67, 0x42, 0x00, 0x1e];
        let pps: &[u8] = &[0x68, 0xCE, 0x3C];
        let idr: &[u8] = &[0x65, 0x88];
        let frame = packet(&[sps, pps, idr]);

        let mut w = NalSampleWriter::new(NalMuxCodec::H264);
        // two frames each repeating SPS/PPS (HW encoders often do this)
        for _ in 0..2 {
            w.push_packet(&frame);
        }
        assert_eq!(w.sps.len(), 1);
        assert_eq!(w.pps.len(), 1);
    }
}
431}