rtmp_rs/media/h264.rs

//! H.264/AVC parsing
//!
//! RTMP transports H.264 in AVCC format (length-prefixed NAL units).
//!
//! AVC video packet structure. The first byte holds FrameType and CodecID;
//! [`H264Data::parse`] receives the bytes that follow it, starting at
//! AVCPacketType:
//! ```text
//! +-----------+----------+---------------+-----------------+------+
//! | FrameType | CodecID  | AVCPacketType | CompositionTime | Data |
//! | (4 bits)  | (4 bits) | (1 byte)      | (3 bytes, SI24) |      |
//! +-----------+----------+---------------+-----------------+------+
//! ```
//!
//! AVCPacketType:
//! - 0: AVC sequence header (AVCDecoderConfigurationRecord)
//! - 1: AVC NALU (one or more NALUs)
//! - 2: AVC end of sequence
//!
//! AVCDecoderConfigurationRecord (sequence header):
//! ```text
//! configurationVersion (1) | AVCProfileIndication (1) | profile_compatibility (1)
//! | AVCLevelIndication (1) | lengthSizeMinusOne (1, lower 2 bits)
//! | numOfSPS (1, lower 5 bits) | { spsLength (2) | spsNALUnit }*
//! | numOfPPS (1) | { ppsLength (2) | ppsNALUnit }*
//! ```
25
26use bytes::{Buf, Bytes};
27
28use crate::error::{MediaError, Result};
29
/// The AVCPacketType byte that follows the FrameType/CodecID byte of an RTMP
/// AVC video packet. The discriminant equals the on-wire value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AvcPacketType {
    /// 0: AVCDecoderConfigurationRecord (SPS/PPS)
    SequenceHeader = 0,
    /// 1: one or more length-prefixed NAL units
    Nalu = 1,
    /// 2: end-of-sequence marker
    EndOfSequence = 2,
}

impl AvcPacketType {
    /// Decodes an AVCPacketType byte, returning `None` for any value above 2.
    pub fn from_byte(b: u8) -> Option<Self> {
        // Ordered so that each variant's index equals its wire value.
        const BY_VALUE: [AvcPacketType; 3] = [
            AvcPacketType::SequenceHeader,
            AvcPacketType::Nalu,
            AvcPacketType::EndOfSequence,
        ];
        BY_VALUE.get(usize::from(b)).copied()
    }
}
51
/// H.264 `nal_unit_type` values recognised by this module (1 through 12).
/// The discriminant equals the on-wire type number.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NaluType {
    /// 1: coded slice of a non-IDR picture
    Slice = 1,
    /// 2: slice data partition A
    SlicePartA = 2,
    /// 3: slice data partition B
    SlicePartB = 3,
    /// 4: slice data partition C
    SlicePartC = 4,
    /// 5: coded slice of an IDR picture (keyframe)
    Idr = 5,
    /// 6: supplemental enhancement information
    Sei = 6,
    /// 7: sequence parameter set
    Sps = 7,
    /// 8: picture parameter set
    Pps = 8,
    /// 9: access unit delimiter
    Aud = 9,
    /// 10: end of sequence
    EndSeq = 10,
    /// 11: end of stream
    EndStream = 11,
    /// 12: filler data
    Filler = 12,
}

impl NaluType {
    /// Decodes the NAL unit type from a NAL header byte.
    ///
    /// Only the low five bits are examined, so the `forbidden_zero_bit` and
    /// `nal_ref_idc` bits are ignored. Returns `None` for types 0 and 13-31.
    pub fn from_byte(b: u8) -> Option<Self> {
        // Index `i` holds the variant whose type number is `i + 1`.
        const BY_TYPE: [NaluType; 12] = [
            NaluType::Slice,
            NaluType::SlicePartA,
            NaluType::SlicePartB,
            NaluType::SlicePartC,
            NaluType::Idr,
            NaluType::Sei,
            NaluType::Sps,
            NaluType::Pps,
            NaluType::Aud,
            NaluType::EndSeq,
            NaluType::EndStream,
            NaluType::Filler,
        ];
        let nal_unit_type = usize::from(b & 0x1F);
        nal_unit_type
            .checked_sub(1)
            .and_then(|index| BY_TYPE.get(index).copied())
    }

    /// `true` only for IDR slices.
    pub fn is_keyframe(&self) -> bool {
        *self == NaluType::Idr
    }

    /// `true` for SPS and PPS units.
    pub fn is_parameter_set(&self) -> bool {
        [NaluType::Sps, NaluType::Pps].contains(self)
    }
}
108
109/// Parsed H.264 data
110#[derive(Debug, Clone)]
111pub enum H264Data {
112    /// Sequence header with SPS/PPS
113    SequenceHeader(AvcConfig),
114
115    /// Video frame (one or more NAL units)
116    Frame {
117        /// Whether this is a keyframe (IDR)
118        keyframe: bool,
119        /// Composition time offset (for B-frames)
120        composition_time: i32,
121        /// NAL units in AVCC format (length-prefixed)
122        nalus: Bytes,
123    },
124
125    /// End of sequence marker
126    EndOfSequence,
127}
128
129/// AVC decoder configuration (from sequence header)
130#[derive(Debug, Clone)]
131pub struct AvcConfig {
132    /// AVC profile (66=Baseline, 77=Main, 100=High, etc.)
133    pub profile: u8,
134    /// Profile compatibility flags
135    pub compatibility: u8,
136    /// AVC level (e.g., 31 = 3.1)
137    pub level: u8,
138    /// NALU length size minus 1 (usually 3, meaning 4-byte lengths)
139    pub nalu_length_size: u8,
140    /// Sequence Parameter Sets
141    pub sps: Vec<Bytes>,
142    /// Picture Parameter Sets
143    pub pps: Vec<Bytes>,
144    /// Raw AVCDecoderConfigurationRecord bytes
145    pub raw: Bytes,
146}
147
148impl AvcConfig {
149    /// Parse from AVCDecoderConfigurationRecord
150    pub fn parse(data: Bytes) -> Result<Self> {
151        if data.len() < 7 {
152            return Err(MediaError::InvalidAvcPacket.into());
153        }
154
155        // Clone the raw data before consuming it
156        let raw = data.clone();
157        let mut data = data;
158
159        let version = data.get_u8();
160        if version != 1 {
161            return Err(MediaError::InvalidAvcPacket.into());
162        }
163
164        let profile = data.get_u8();
165        let compatibility = data.get_u8();
166        let level = data.get_u8();
167        let nalu_length_size = (data.get_u8() & 0x03) + 1;
168
169        // Parse SPS
170        let num_sps = (data.get_u8() & 0x1F) as usize;
171        let mut sps = Vec::with_capacity(num_sps);
172        for _ in 0..num_sps {
173            if data.len() < 2 {
174                return Err(MediaError::InvalidAvcPacket.into());
175            }
176            let sps_len = data.get_u16() as usize;
177            if data.len() < sps_len {
178                return Err(MediaError::InvalidAvcPacket.into());
179            }
180            sps.push(data.copy_to_bytes(sps_len));
181        }
182
183        // Parse PPS
184        if data.is_empty() {
185            return Err(MediaError::InvalidAvcPacket.into());
186        }
187        let num_pps = data.get_u8() as usize;
188        let mut pps = Vec::with_capacity(num_pps);
189        for _ in 0..num_pps {
190            if data.len() < 2 {
191                return Err(MediaError::InvalidAvcPacket.into());
192            }
193            let pps_len = data.get_u16() as usize;
194            if data.len() < pps_len {
195                return Err(MediaError::InvalidAvcPacket.into());
196            }
197            pps.push(data.copy_to_bytes(pps_len));
198        }
199
200        Ok(AvcConfig {
201            profile,
202            compatibility,
203            level,
204            nalu_length_size,
205            sps,
206            pps,
207            raw,
208        })
209    }
210
211    /// Get profile name
212    pub fn profile_name(&self) -> &'static str {
213        match self.profile {
214            66 => "Baseline",
215            77 => "Main",
216            88 => "Extended",
217            100 => "High",
218            110 => "High 10",
219            122 => "High 4:2:2",
220            244 => "High 4:4:4",
221            _ => "Unknown",
222        }
223    }
224
225    /// Get level as string (e.g., "3.1")
226    pub fn level_string(&self) -> String {
227        format!("{}.{}", self.level / 10, self.level % 10)
228    }
229}
230
231impl H264Data {
232    /// Parse from RTMP video data (after frame type and codec ID bytes)
233    pub fn parse(mut data: Bytes) -> Result<Self> {
234        if data.len() < 4 {
235            return Err(MediaError::InvalidAvcPacket.into());
236        }
237
238        let packet_type = data.get_u8();
239
240        // Composition time (signed 24-bit)
241        let ct0 = data.get_u8() as i32;
242        let ct1 = data.get_u8() as i32;
243        let ct2 = data.get_u8() as i32;
244        let composition_time = (ct0 << 16) | (ct1 << 8) | ct2;
245        // Sign extend from 24 bits
246        let composition_time = if composition_time & 0x800000 != 0 {
247            composition_time | !0xFFFFFF
248        } else {
249            composition_time
250        };
251
252        match AvcPacketType::from_byte(packet_type) {
253            Some(AvcPacketType::SequenceHeader) => {
254                let config = AvcConfig::parse(data)?;
255                Ok(H264Data::SequenceHeader(config))
256            }
257            Some(AvcPacketType::Nalu) => {
258                // Check for IDR in the NAL units
259                let keyframe = Self::contains_idr(&data);
260                Ok(H264Data::Frame {
261                    keyframe,
262                    composition_time,
263                    nalus: data,
264                })
265            }
266            Some(AvcPacketType::EndOfSequence) => Ok(H264Data::EndOfSequence),
267            None => Err(MediaError::InvalidAvcPacket.into()),
268        }
269    }
270
271    /// Check if NAL units contain an IDR frame
272    fn contains_idr(data: &Bytes) -> bool {
273        let mut offset = 0;
274        while offset + 4 < data.len() {
275            // Read NALU length (assume 4 bytes, most common)
276            let len = u32::from_be_bytes([
277                data[offset],
278                data[offset + 1],
279                data[offset + 2],
280                data[offset + 3],
281            ]) as usize;
282            offset += 4;
283
284            if offset >= data.len() {
285                break;
286            }
287
288            // Check NAL unit type
289            let nalu_type = NaluType::from_byte(data[offset]);
290            if nalu_type == Some(NaluType::Idr) {
291                return true;
292            }
293
294            offset += len;
295        }
296        false
297    }
298
299    /// Check if this is a keyframe
300    pub fn is_keyframe(&self) -> bool {
301        match self {
302            H264Data::SequenceHeader(_) => true, // Sequence headers are keyframe-associated
303            H264Data::Frame { keyframe, .. } => *keyframe,
304            H264Data::EndOfSequence => false,
305        }
306    }
307
308    /// Check if this is a sequence header
309    pub fn is_sequence_header(&self) -> bool {
310        matches!(self, H264Data::SequenceHeader(_))
311    }
312}
313
/// Iterator over the NAL units of an AVCC (length-prefixed) buffer.
///
/// Each item is one NAL unit's payload, without its length prefix.
/// Iteration ends at the end of the buffer, or as soon as the data is
/// malformed:
/// - a length prefix is cut short;
/// - a NAL unit's declared length runs past the end of the buffer;
/// - the prefix size is unusable (zero, or wider than `usize`).
///
/// Once `next` has returned `None` it keeps returning `None`.
#[derive(Debug, Clone)]
pub struct NaluIterator<'a> {
    data: &'a [u8],
    offset: usize,
    nalu_length_size: usize,
}

impl<'a> NaluIterator<'a> {
    /// Creates an iterator over `data`, whose NAL units are each preceded by a
    /// big-endian length of `nalu_length_size` bytes (for example the value of
    /// `AvcConfig::nalu_length_size`).
    pub fn new(data: &'a [u8], nalu_length_size: u8) -> Self {
        Self {
            data,
            offset: 0,
            nalu_length_size: nalu_length_size as usize,
        }
    }
}

impl<'a> Iterator for NaluIterator<'a> {
    type Item = &'a [u8];

    fn next(&mut self) -> Option<Self::Item> {
        let data: &'a [u8] = self.data;
        let size = self.nalu_length_size;

        // A zero-width prefix would never advance the cursor (yielding empty
        // slices forever), and a prefix wider than `usize` cannot be decoded.
        if size == 0 || size > std::mem::size_of::<usize>() {
            return None;
        }

        let rest = data.get(self.offset..)?;
        if rest.len() < size {
            self.offset = data.len();
            return None;
        }

        let (prefix, payload) = rest.split_at(size);
        let len = prefix
            .iter()
            .fold(0usize, |acc, &b| (acc << 8) | usize::from(b));

        match payload.get(..len) {
            Some(nalu) => {
                // `size + len <= rest.len()`, so this cannot overflow.
                self.offset += size + len;
                Some(nalu)
            }
            None => {
                // Truncated NAL unit: stop for good instead of resuming
                // mid-payload and reading data bytes as a length prefix.
                self.offset = data.len();
                None
            }
        }
    }
}

impl std::iter::FusedIterator for NaluIterator<'_> {}
355
#[cfg(test)]
mod tests {
    use super::*;

    /// Unwraps an `H264Data::Frame` into `(keyframe, composition_time, nalus)`,
    /// failing the test on any other variant instead of silently passing.
    fn expect_frame(data: H264Data) -> (bool, i32, Bytes) {
        match data {
            H264Data::Frame {
                keyframe,
                composition_time,
                nalus,
            } => (keyframe, composition_time, nalus),
            other => panic!("Expected Frame, got {:?}", other),
        }
    }

    #[test]
    fn test_nalu_type() {
        assert_eq!(NaluType::from_byte(0x65), Some(NaluType::Idr));
        assert_eq!(NaluType::from_byte(0x67), Some(NaluType::Sps));
        assert_eq!(NaluType::from_byte(0x68), Some(NaluType::Pps));
        assert_eq!(NaluType::from_byte(0x41), Some(NaluType::Slice));
    }

    #[test]
    fn test_avc_config_parse() {
        // Minimal valid AVCDecoderConfigurationRecord
        let data = Bytes::from_static(&[
            0x01, // version
            0x64, // profile (High)
            0x00, // compatibility
            0x1F, // level 3.1
            0xFF, // nalu length size = 4
            0xE1, // 1 SPS
            0x00, 0x04, // SPS length
            0x67, 0x64, 0x00, 0x1F, // SPS data
            0x01, // 1 PPS
            0x00, 0x03, // PPS length
            0x68, 0xEF, 0x38, // PPS data
        ]);

        let config = AvcConfig::parse(data.clone()).unwrap();
        assert_eq!(config.profile, 100);
        assert_eq!(config.level, 31);
        assert_eq!(config.nalu_length_size, 4);
        assert_eq!(config.sps.len(), 1);
        assert_eq!(config.pps.len(), 1);
        assert_eq!(config.profile_name(), "High");
        assert_eq!(config.level_string(), "3.1");
        // Verify raw field contains the original bytes
        assert_eq!(config.raw.len(), data.len());
        assert_eq!(config.raw, data);
    }

    #[test]
    fn test_avc_packet_type() {
        assert_eq!(
            AvcPacketType::from_byte(0),
            Some(AvcPacketType::SequenceHeader)
        );
        assert_eq!(AvcPacketType::from_byte(1), Some(AvcPacketType::Nalu));
        assert_eq!(
            AvcPacketType::from_byte(2),
            Some(AvcPacketType::EndOfSequence)
        );
        assert_eq!(AvcPacketType::from_byte(3), None);
        assert_eq!(AvcPacketType::from_byte(255), None);
    }

    #[test]
    fn test_nalu_type_parsing() {
        // Test all documented NALU types
        assert_eq!(NaluType::from_byte(0x01), Some(NaluType::Slice));
        assert_eq!(NaluType::from_byte(0x02), Some(NaluType::SlicePartA));
        assert_eq!(NaluType::from_byte(0x03), Some(NaluType::SlicePartB));
        assert_eq!(NaluType::from_byte(0x04), Some(NaluType::SlicePartC));
        assert_eq!(NaluType::from_byte(0x05), Some(NaluType::Idr));
        assert_eq!(NaluType::from_byte(0x06), Some(NaluType::Sei));
        assert_eq!(NaluType::from_byte(0x07), Some(NaluType::Sps));
        assert_eq!(NaluType::from_byte(0x08), Some(NaluType::Pps));
        assert_eq!(NaluType::from_byte(0x09), Some(NaluType::Aud));
        assert_eq!(NaluType::from_byte(0x0A), Some(NaluType::EndSeq));
        assert_eq!(NaluType::from_byte(0x0B), Some(NaluType::EndStream));
        assert_eq!(NaluType::from_byte(0x0C), Some(NaluType::Filler));

        // Test with forbidden_zero_bit and nal_ref_idc bits set
        assert_eq!(NaluType::from_byte(0x65), Some(NaluType::Idr)); // 0x65 & 0x1F = 5
        assert_eq!(NaluType::from_byte(0x67), Some(NaluType::Sps)); // 0x67 & 0x1F = 7
    }

    #[test]
    fn test_nalu_type_unknown() {
        assert_eq!(NaluType::from_byte(0x00), None);
        assert_eq!(NaluType::from_byte(0x0D), None);
        assert_eq!(NaluType::from_byte(0x1F), None);
        assert_eq!(NaluType::from_byte(0x60), None); // type bits are 0
    }

    #[test]
    fn test_nalu_type_is_keyframe() {
        assert!(NaluType::Idr.is_keyframe());
        assert!(!NaluType::Slice.is_keyframe());
        assert!(!NaluType::Sps.is_keyframe());
        assert!(!NaluType::Pps.is_keyframe());
    }

    #[test]
    fn test_nalu_type_is_parameter_set() {
        assert!(NaluType::Sps.is_parameter_set());
        assert!(NaluType::Pps.is_parameter_set());
        assert!(!NaluType::Idr.is_parameter_set());
        assert!(!NaluType::Slice.is_parameter_set());
    }

    #[test]
    fn test_h264_data_sequence_header() {
        // Simulate parsing a sequence header packet
        let data = Bytes::from_static(&[
            0x00, // AVC sequence header
            0x00, 0x00, 0x00, // composition time (0)
            // AVCDecoderConfigurationRecord
            0x01, 0x64, 0x00, 0x1F, 0xFF, // version, profile, compat, level, length-1
            0xE1, // 1 SPS
            0x00, 0x04, 0x67, 0x64, 0x00, 0x1F, // SPS
            0x01, // 1 PPS
            0x00, 0x03, 0x68, 0xEF, 0x38, // PPS
        ]);

        let h264 = H264Data::parse(data).unwrap();
        assert!(h264.is_sequence_header());
        assert!(h264.is_keyframe());

        match h264 {
            H264Data::SequenceHeader(config) => assert_eq!(config.profile, 100),
            other => panic!("Expected SequenceHeader, got {:?}", other),
        }
    }

    #[test]
    fn test_h264_data_end_of_sequence() {
        let data = Bytes::from_static(&[
            0x02, // End of sequence
            0x00, 0x00, 0x00, // composition time
        ]);

        let h264 = H264Data::parse(data).unwrap();
        assert!(matches!(h264, H264Data::EndOfSequence));
        assert!(!h264.is_keyframe());
        assert!(!h264.is_sequence_header());
    }

    #[test]
    fn test_h264_data_nalu_keyframe() {
        // Create NALU data with IDR frame
        let data = Bytes::from_static(&[
            0x01, // AVC NALU
            0x00, 0x00, 0x00, // composition time (0)
            // NALU with length prefix
            0x00, 0x00, 0x00, 0x05, // length = 5
            0x65, 0x88, 0x84, 0x00, 0x00, // IDR NALU (type 5)
        ]);

        let h264 = H264Data::parse(data).unwrap();
        assert!(h264.is_keyframe());
        assert!(!h264.is_sequence_header());

        let (keyframe, composition_time, nalus) = expect_frame(h264);
        assert!(keyframe);
        assert_eq!(composition_time, 0);
        assert_eq!(nalus.len(), 9); // length prefix + payload, untouched
    }

    #[test]
    fn test_h264_data_nalu_p_frame() {
        // Create NALU data with non-IDR slice
        let data = Bytes::from_static(&[
            0x01, // AVC NALU
            0x00, 0x00, 0x00, // composition time
            // NALU with length prefix
            0x00, 0x00, 0x00, 0x05, // length = 5
            0x41, 0x9A, 0x00, 0x00, 0x00, // Non-IDR slice (type 1)
        ]);

        let h264 = H264Data::parse(data).unwrap();
        assert!(!h264.is_keyframe());

        let (keyframe, _, _) = expect_frame(h264);
        assert!(!keyframe);
    }

    #[test]
    fn test_h264_data_idr_after_other_nalu() {
        // An access unit delimiter followed by an IDR slice
        let data = Bytes::from_static(&[
            0x01, // AVC NALU
            0x00, 0x00, 0x00, // composition time
            0x00, 0x00, 0x00, 0x02, 0x09, 0xF0, // AUD
            0x00, 0x00, 0x00, 0x03, 0x65, 0x88, 0x84, // IDR
        ]);

        let (keyframe, _, _) = expect_frame(H264Data::parse(data).unwrap());
        assert!(keyframe);
    }

    #[test]
    fn test_h264_composition_time_positive() {
        let data = Bytes::from_static(&[
            0x01, // AVC NALU
            0x00, 0x01, 0x00, // composition time = 256
            0x00, 0x00, 0x00, 0x01, // length
            0x41, // Non-IDR
        ]);

        let (_, composition_time, _) = expect_frame(H264Data::parse(data).unwrap());
        assert_eq!(composition_time, 256);
    }

    #[test]
    fn test_h264_composition_time_negative() {
        // Negative composition time (sign-extended from 24 bits)
        let data = Bytes::from_static(&[
            0x01, // AVC NALU
            0xFF, 0xFF, 0x00, // composition time = -256 (as signed 24-bit)
            0x00, 0x00, 0x00, 0x01, // length
            0x41, // Non-IDR
        ]);

        let (_, composition_time, _) = expect_frame(H264Data::parse(data).unwrap());
        assert_eq!(composition_time, -256);
    }

    #[test]
    fn test_h264_composition_time_min() {
        // Most negative signed 24-bit value
        let data = Bytes::from_static(&[
            0x01, // AVC NALU
            0x80, 0x00, 0x00, // composition time = -8388608
            0x00, 0x00, 0x00, 0x01, // length
            0x41, // Non-IDR
        ]);

        let (_, composition_time, _) = expect_frame(H264Data::parse(data).unwrap());
        assert_eq!(composition_time, -8_388_608);
    }

    #[test]
    fn test_h264_data_invalid_packet_type() {
        let data = Bytes::from_static(&[
            0x03, // Invalid packet type
            0x00, 0x00, 0x00,
        ]);

        let result = H264Data::parse(data);
        assert!(result.is_err());
    }

    #[test]
    fn test_h264_data_too_short() {
        let data = Bytes::from_static(&[0x00, 0x00]); // Less than 4 bytes
        let result = H264Data::parse(data);
        assert!(result.is_err());
    }

    #[test]
    fn test_avc_config_profile_names() {
        // Test various profile names
        let profiles = [
            (66, "Baseline"),
            (77, "Main"),
            (88, "Extended"),
            (100, "High"),
            (110, "High 10"),
            (122, "High 4:2:2"),
            (244, "High 4:4:4"),
            (99, "Unknown"),
        ];

        for (profile, expected_name) in profiles {
            let config = AvcConfig {
                profile,
                compatibility: 0,
                level: 31,
                nalu_length_size: 4,
                sps: vec![],
                pps: vec![],
                raw: Bytes::new(),
            };
            assert_eq!(config.profile_name(), expected_name);
        }
    }

    #[test]
    fn test_avc_config_level_string() {
        let config = AvcConfig {
            profile: 100,
            compatibility: 0,
            level: 41, // Level 4.1
            nalu_length_size: 4,
            sps: vec![],
            pps: vec![],
            raw: Bytes::new(),
        };
        assert_eq!(config.level_string(), "4.1");

        let config2 = AvcConfig {
            profile: 100,
            compatibility: 0,
            level: 52, // Level 5.2
            nalu_length_size: 4,
            sps: vec![],
            pps: vec![],
            raw: Bytes::new(),
        };
        assert_eq!(config2.level_string(), "5.2");
    }

    #[test]
    fn test_avc_config_invalid_version() {
        let data = Bytes::from_static(&[
            0x02, // Invalid version (should be 1)
            0x64, 0x00, 0x1F, 0xFF, 0xE1, 0x00, 0x04, 0x67, 0x64, 0x00, 0x1F, 0x01, 0x00, 0x03,
            0x68, 0xEF, 0x38,
        ]);

        let result = AvcConfig::parse(data);
        assert!(result.is_err());
    }

    #[test]
    fn test_avc_config_too_short() {
        let data = Bytes::from_static(&[0x01, 0x64, 0x00]); // Less than 7 bytes
        let result = AvcConfig::parse(data);
        assert!(result.is_err());
    }

    #[test]
    fn test_avc_config_truncated_sps() {
        let data = Bytes::from_static(&[
            0x01, 0x64, 0x00, 0x1F, 0xFF, // version, profile, compat, level, length-1
            0xE1, // 1 SPS
            0x00, 0x10, // SPS length = 16 ...
            0x67, // ... but only 1 byte follows
        ]);
        assert!(AvcConfig::parse(data).is_err());
    }

    #[test]
    fn test_avc_config_missing_pps_count() {
        let data = Bytes::from_static(&[
            0x01, 0x64, 0x00, 0x1F, 0xFF, // version, profile, compat, level, length-1
            0xE1, // 1 SPS
            0x00, 0x01, 0x67, // complete 1-byte SPS, then nothing
        ]);
        assert!(AvcConfig::parse(data).is_err());
    }

    #[test]
    fn test_nalu_iterator() {
        // Create AVCC-format data with multiple NALUs
        let data: &[u8] = &[
            0x00, 0x00, 0x00, 0x03, // length = 3
            0x67, 0x64, 0x00, // SPS NALU
            0x00, 0x00, 0x00, 0x02, // length = 2
            0x68, 0xEF, // PPS NALU
        ];

        let mut iter = NaluIterator::new(data, 4);

        let nalu1 = iter.next().unwrap();
        assert_eq!(nalu1.len(), 3);
        assert_eq!(NaluType::from_byte(nalu1[0]), Some(NaluType::Sps));

        let nalu2 = iter.next().unwrap();
        assert_eq!(nalu2.len(), 2);
        assert_eq!(NaluType::from_byte(nalu2[0]), Some(NaluType::Pps));

        assert!(iter.next().is_none());
    }

    #[test]
    fn test_nalu_iterator_different_length_sizes() {
        // Test with 2-byte length prefix
        let data: &[u8] = &[
            0x00, 0x02, // length = 2
            0x65, 0x88, // IDR NALU
        ];

        let mut iter = NaluIterator::new(data, 2);
        let nalu = iter.next().unwrap();
        assert_eq!(nalu.len(), 2);
    }

    #[test]
    fn test_nalu_iterator_empty() {
        let data: &[u8] = &[];
        let mut iter = NaluIterator::new(data, 4);
        assert!(iter.next().is_none());
    }

    #[test]
    fn test_nalu_iterator_truncated() {
        // Length says 10 bytes but only 3 available
        let data: &[u8] = &[
            0x00, 0x00, 0x00, 0x0A, // length = 10
            0x67, 0x64, 0x00, // Only 3 bytes
        ];

        let mut iter = NaluIterator::new(data, 4);
        assert!(iter.next().is_none()); // Should return None for truncated data
    }
}