rtmp_rs/media/
h264.rs

1//! H.264/AVC parsing
2//!
3//! RTMP transports H.264 in AVCC format (length-prefixed NAL units).
4//!
5//! AVC Video Packet Structure:
6//! ```text
//! +-----------+----------+---------------+-----------------+------+
//! | FrameType | CodecID  | AVCPacketType | CompositionTime | Data |
//! | (4 bits)  | (4 bits) | (1 byte)      | (3 bytes, SI24) |      |
//! +-----------+----------+---------------+-----------------+------+
11//! ```
12//!
13//! AVCPacketType:
14//! - 0: AVC sequence header (AVCDecoderConfigurationRecord)
15//! - 1: AVC NALU (one or more NALUs)
16//! - 2: AVC end of sequence
17//!
18//! AVCDecoderConfigurationRecord (sequence header):
19//! ```text
20//! configurationVersion (1) | AVCProfileIndication (1) | profile_compatibility (1)
21//! | AVCLevelIndication (1) | lengthSizeMinusOne (1, lower 2 bits)
22//! | numOfSPS (1, lower 5 bits) | { spsLength (2) | spsNALUnit }*
23//! | numOfPPS (1) | { ppsLength (2) | ppsNALUnit }*
24//! ```
25
26use bytes::{Buf, Bytes};
27
28use crate::error::{MediaError, Result};
29
/// AVCPacketType: the byte that follows the FrameType/CodecID byte in an RTMP AVC video payload.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AvcPacketType {
    /// 0: sequence header carrying an AVCDecoderConfigurationRecord
    SequenceHeader = 0,
    /// 1: one or more length-prefixed NAL units
    Nalu = 1,
    /// 2: end-of-sequence marker
    EndOfSequence = 2,
}

impl AvcPacketType {
    /// Maps a raw AVCPacketType byte to its variant.
    ///
    /// Returns `None` for any value other than 0, 1 or 2.
    pub fn from_byte(b: u8) -> Option<Self> {
        // Indexed by the wire value.
        const BY_VALUE: [AvcPacketType; 3] = [
            AvcPacketType::SequenceHeader,
            AvcPacketType::Nalu,
            AvcPacketType::EndOfSequence,
        ];
        BY_VALUE.get(usize::from(b)).copied()
    }
}
51
/// NAL unit type: the `nal_unit_type` field held in the low 5 bits of a NAL unit's header byte.
///
/// Only types 1 through 12 are modelled; [`NaluType::from_byte`] reports every other
/// value as `None`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NaluType {
    /// Coded slice of a non-IDR picture
    Slice = 1,
    /// Slice data partition A
    SlicePartA = 2,
    /// Slice data partition B
    SlicePartB = 3,
    /// Slice data partition C
    SlicePartC = 4,
    /// Coded slice of an IDR picture (keyframe)
    Idr = 5,
    /// Supplemental enhancement information
    Sei = 6,
    /// Sequence parameter set
    Sps = 7,
    /// Picture parameter set
    Pps = 8,
    /// Access unit delimiter
    Aud = 9,
    /// End of sequence
    EndSeq = 10,
    /// End of stream
    EndStream = 11,
    /// Filler data
    Filler = 12,
}

impl NaluType {
    /// Decodes the NAL unit type from a NAL header byte.
    ///
    /// Only the low 5 bits are inspected, so the raw header byte can be passed directly
    /// (e.g. `0x65` yields `Idr`).
    pub fn from_byte(b: u8) -> Option<Self> {
        // Entry `i` holds nal_unit_type `i + 1`.
        const BY_TYPE: [NaluType; 12] = [
            NaluType::Slice,
            NaluType::SlicePartA,
            NaluType::SlicePartB,
            NaluType::SlicePartC,
            NaluType::Idr,
            NaluType::Sei,
            NaluType::Sps,
            NaluType::Pps,
            NaluType::Aud,
            NaluType::EndSeq,
            NaluType::EndStream,
            NaluType::Filler,
        ];
        let nal_unit_type = usize::from(b & 0x1F);
        // Type 0 has no table entry; checked_sub turns it into None.
        nal_unit_type
            .checked_sub(1)
            .and_then(|index| BY_TYPE.get(index).copied())
    }

    /// Returns `true` only for IDR slices.
    pub fn is_keyframe(&self) -> bool {
        *self == NaluType::Idr
    }

    /// Returns `true` for sequence and picture parameter sets.
    pub fn is_parameter_set(&self) -> bool {
        *self == NaluType::Sps || *self == NaluType::Pps
    }
}
108
109/// Parsed H.264 data
110#[derive(Debug, Clone)]
111pub enum H264Data {
112    /// Sequence header with SPS/PPS
113    SequenceHeader(AvcConfig),
114
115    /// Video frame (one or more NAL units)
116    Frame {
117        /// Whether this is a keyframe (IDR)
118        keyframe: bool,
119        /// Composition time offset (for B-frames)
120        composition_time: i32,
121        /// NAL units in AVCC format (length-prefixed)
122        nalus: Bytes,
123    },
124
125    /// End of sequence marker
126    EndOfSequence,
127}
128
129/// AVC decoder configuration (from sequence header)
130#[derive(Debug, Clone)]
131pub struct AvcConfig {
132    /// AVC profile (66=Baseline, 77=Main, 100=High, etc.)
133    pub profile: u8,
134    /// Profile compatibility flags
135    pub compatibility: u8,
136    /// AVC level (e.g., 31 = 3.1)
137    pub level: u8,
138    /// NALU length size minus 1 (usually 3, meaning 4-byte lengths)
139    pub nalu_length_size: u8,
140    /// Sequence Parameter Sets
141    pub sps: Vec<Bytes>,
142    /// Picture Parameter Sets
143    pub pps: Vec<Bytes>,
144}
145
146impl AvcConfig {
147    /// Parse from AVCDecoderConfigurationRecord
148    pub fn parse(mut data: Bytes) -> Result<Self> {
149        if data.len() < 7 {
150            return Err(MediaError::InvalidAvcPacket.into());
151        }
152
153        let version = data.get_u8();
154        if version != 1 {
155            return Err(MediaError::InvalidAvcPacket.into());
156        }
157
158        let profile = data.get_u8();
159        let compatibility = data.get_u8();
160        let level = data.get_u8();
161        let nalu_length_size = (data.get_u8() & 0x03) + 1;
162
163        // Parse SPS
164        let num_sps = (data.get_u8() & 0x1F) as usize;
165        let mut sps = Vec::with_capacity(num_sps);
166        for _ in 0..num_sps {
167            if data.len() < 2 {
168                return Err(MediaError::InvalidAvcPacket.into());
169            }
170            let sps_len = data.get_u16() as usize;
171            if data.len() < sps_len {
172                return Err(MediaError::InvalidAvcPacket.into());
173            }
174            sps.push(data.copy_to_bytes(sps_len));
175        }
176
177        // Parse PPS
178        if data.is_empty() {
179            return Err(MediaError::InvalidAvcPacket.into());
180        }
181        let num_pps = data.get_u8() as usize;
182        let mut pps = Vec::with_capacity(num_pps);
183        for _ in 0..num_pps {
184            if data.len() < 2 {
185                return Err(MediaError::InvalidAvcPacket.into());
186            }
187            let pps_len = data.get_u16() as usize;
188            if data.len() < pps_len {
189                return Err(MediaError::InvalidAvcPacket.into());
190            }
191            pps.push(data.copy_to_bytes(pps_len));
192        }
193
194        Ok(AvcConfig {
195            profile,
196            compatibility,
197            level,
198            nalu_length_size,
199            sps,
200            pps,
201        })
202    }
203
204    /// Get profile name
205    pub fn profile_name(&self) -> &'static str {
206        match self.profile {
207            66 => "Baseline",
208            77 => "Main",
209            88 => "Extended",
210            100 => "High",
211            110 => "High 10",
212            122 => "High 4:2:2",
213            244 => "High 4:4:4",
214            _ => "Unknown",
215        }
216    }
217
218    /// Get level as string (e.g., "3.1")
219    pub fn level_string(&self) -> String {
220        format!("{}.{}", self.level / 10, self.level % 10)
221    }
222}
223
224impl H264Data {
225    /// Parse from RTMP video data (after frame type and codec ID bytes)
226    pub fn parse(mut data: Bytes) -> Result<Self> {
227        if data.len() < 4 {
228            return Err(MediaError::InvalidAvcPacket.into());
229        }
230
231        let packet_type = data.get_u8();
232
233        // Composition time (signed 24-bit)
234        let ct0 = data.get_u8() as i32;
235        let ct1 = data.get_u8() as i32;
236        let ct2 = data.get_u8() as i32;
237        let composition_time = (ct0 << 16) | (ct1 << 8) | ct2;
238        // Sign extend from 24 bits
239        let composition_time = if composition_time & 0x800000 != 0 {
240            composition_time | !0xFFFFFF
241        } else {
242            composition_time
243        };
244
245        match AvcPacketType::from_byte(packet_type) {
246            Some(AvcPacketType::SequenceHeader) => {
247                let config = AvcConfig::parse(data)?;
248                Ok(H264Data::SequenceHeader(config))
249            }
250            Some(AvcPacketType::Nalu) => {
251                // Check for IDR in the NAL units
252                let keyframe = Self::contains_idr(&data);
253                Ok(H264Data::Frame {
254                    keyframe,
255                    composition_time,
256                    nalus: data,
257                })
258            }
259            Some(AvcPacketType::EndOfSequence) => Ok(H264Data::EndOfSequence),
260            None => Err(MediaError::InvalidAvcPacket.into()),
261        }
262    }
263
264    /// Check if NAL units contain an IDR frame
265    fn contains_idr(data: &Bytes) -> bool {
266        let mut offset = 0;
267        while offset + 4 < data.len() {
268            // Read NALU length (assume 4 bytes, most common)
269            let len = u32::from_be_bytes([
270                data[offset],
271                data[offset + 1],
272                data[offset + 2],
273                data[offset + 3],
274            ]) as usize;
275            offset += 4;
276
277            if offset >= data.len() {
278                break;
279            }
280
281            // Check NAL unit type
282            let nalu_type = NaluType::from_byte(data[offset]);
283            if nalu_type == Some(NaluType::Idr) {
284                return true;
285            }
286
287            offset += len;
288        }
289        false
290    }
291
292    /// Check if this is a keyframe
293    pub fn is_keyframe(&self) -> bool {
294        match self {
295            H264Data::SequenceHeader(_) => true, // Sequence headers are keyframe-associated
296            H264Data::Frame { keyframe, .. } => *keyframe,
297            H264Data::EndOfSequence => false,
298        }
299    }
300
301    /// Check if this is a sequence header
302    pub fn is_sequence_header(&self) -> bool {
303        matches!(self, H264Data::SequenceHeader(_))
304    }
305}
306
/// Iterator over the NAL units of an AVCC (length-prefixed) buffer.
///
/// Each item is the payload of one NAL unit, without its length prefix. Iteration stops at the
/// end of the buffer or at the first truncated prefix/payload. It stays stopped afterwards
/// (the iterator is fused).
pub struct NaluIterator<'a> {
    data: &'a [u8],
    offset: usize,
    nalu_length_size: usize,
}

impl<'a> NaluIterator<'a> {
    /// Creates an iterator over `data`, whose NAL units carry big-endian length prefixes of
    /// `nalu_length_size` bytes (normally [`AvcConfig::nalu_length_size`]).
    ///
    /// A `nalu_length_size` of 0 yields no items.
    pub fn new(data: &'a [u8], nalu_length_size: u8) -> Self {
        Self {
            data,
            offset: 0,
            nalu_length_size: nalu_length_size as usize,
        }
    }

    /// Splits one NAL unit off the front of `buf`.
    ///
    /// Returns the payload and the total bytes consumed (prefix + payload). Returns `None` if
    /// `length_size` is 0, or if `buf` is too short for the prefix or the declared payload.
    fn split_nalu(buf: &'a [u8], length_size: usize) -> Option<(&'a [u8], usize)> {
        if length_size == 0 || buf.len() < length_size {
            return None;
        }
        let (prefix, rest) = buf.split_at(length_size);
        let len = prefix
            .iter()
            .fold(0usize, |acc, &byte| (acc << 8) | usize::from(byte));
        let payload = rest.get(..len)?;
        Some((payload, length_size + len))
    }
}

impl<'a> Iterator for NaluIterator<'a> {
    type Item = &'a [u8];

    fn next(&mut self) -> Option<Self::Item> {
        let data: &'a [u8] = self.data;
        // Invariant: offset <= data.len(); it only advances by validated amounts.
        match Self::split_nalu(&data[self.offset..], self.nalu_length_size) {
            Some((nalu, consumed)) => {
                self.offset += consumed;
                Some(nalu)
            }
            None => {
                // Park at the end so later calls keep returning None instead of
                // reinterpreting payload bytes as a length prefix.
                self.offset = data.len();
                None
            }
        }
    }
}

impl std::iter::FusedIterator for NaluIterator<'_> {}
348
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_nalu_type() {
        assert_eq!(NaluType::from_byte(0x65), Some(NaluType::Idr));
        assert_eq!(NaluType::from_byte(0x67), Some(NaluType::Sps));
        assert_eq!(NaluType::from_byte(0x68), Some(NaluType::Pps));
        assert_eq!(NaluType::from_byte(0x41), Some(NaluType::Slice));
    }

    #[test]
    fn test_avc_config_parse() {
        // Minimal valid AVCDecoderConfigurationRecord
        let data = Bytes::from_static(&[
            0x01, // version
            0x64, // profile (High)
            0x00, // compatibility
            0x1F, // level 3.1
            0xFF, // nalu length size = 4
            0xE1, // 1 SPS
            0x00, 0x04, // SPS length
            0x67, 0x64, 0x00, 0x1F, // SPS data
            0x01, // 1 PPS
            0x00, 0x03, // PPS length
            0x68, 0xEF, 0x38, // PPS data
        ]);

        let config = AvcConfig::parse(data).unwrap();
        assert_eq!(config.profile, 100);
        assert_eq!(config.level, 31);
        assert_eq!(config.nalu_length_size, 4);
        assert_eq!(config.sps.len(), 1);
        assert_eq!(config.pps.len(), 1);
        assert_eq!(config.sps[0], Bytes::from_static(&[0x67, 0x64, 0x00, 0x1F]));
        assert_eq!(config.pps[0], Bytes::from_static(&[0x68, 0xEF, 0x38]));
        assert_eq!(config.profile_name(), "High");
        assert_eq!(config.level_string(), "3.1");
    }

    #[test]
    fn test_avc_packet_type() {
        assert_eq!(
            AvcPacketType::from_byte(0),
            Some(AvcPacketType::SequenceHeader)
        );
        assert_eq!(AvcPacketType::from_byte(1), Some(AvcPacketType::Nalu));
        assert_eq!(
            AvcPacketType::from_byte(2),
            Some(AvcPacketType::EndOfSequence)
        );
        assert_eq!(AvcPacketType::from_byte(3), None);
        assert_eq!(AvcPacketType::from_byte(255), None);
    }

    #[test]
    fn test_nalu_type_parsing() {
        // Test all documented NALU types
        assert_eq!(NaluType::from_byte(0x01), Some(NaluType::Slice));
        assert_eq!(NaluType::from_byte(0x02), Some(NaluType::SlicePartA));
        assert_eq!(NaluType::from_byte(0x03), Some(NaluType::SlicePartB));
        assert_eq!(NaluType::from_byte(0x04), Some(NaluType::SlicePartC));
        assert_eq!(NaluType::from_byte(0x05), Some(NaluType::Idr));
        assert_eq!(NaluType::from_byte(0x06), Some(NaluType::Sei));
        assert_eq!(NaluType::from_byte(0x07), Some(NaluType::Sps));
        assert_eq!(NaluType::from_byte(0x08), Some(NaluType::Pps));
        assert_eq!(NaluType::from_byte(0x09), Some(NaluType::Aud));
        assert_eq!(NaluType::from_byte(0x0A), Some(NaluType::EndSeq));
        assert_eq!(NaluType::from_byte(0x0B), Some(NaluType::EndStream));
        assert_eq!(NaluType::from_byte(0x0C), Some(NaluType::Filler));

        // Test with forbidden_zero_bit and nal_ref_idc bits set
        assert_eq!(NaluType::from_byte(0x65), Some(NaluType::Idr)); // 0x65 & 0x1F = 5
        assert_eq!(NaluType::from_byte(0x67), Some(NaluType::Sps)); // 0x67 & 0x1F = 7
    }

    #[test]
    fn test_nalu_type_is_keyframe() {
        assert!(NaluType::Idr.is_keyframe());
        assert!(!NaluType::Slice.is_keyframe());
        assert!(!NaluType::Sps.is_keyframe());
        assert!(!NaluType::Pps.is_keyframe());
    }

    #[test]
    fn test_nalu_type_is_parameter_set() {
        assert!(NaluType::Sps.is_parameter_set());
        assert!(NaluType::Pps.is_parameter_set());
        assert!(!NaluType::Idr.is_parameter_set());
        assert!(!NaluType::Slice.is_parameter_set());
    }

    #[test]
    fn test_h264_data_sequence_header() {
        // Simulate parsing a sequence header packet
        let data = Bytes::from_static(&[
            0x00, // AVC sequence header
            0x00, 0x00, 0x00, // composition time (0)
            // AVCDecoderConfigurationRecord
            0x01, 0x64, 0x00, 0x1F, 0xFF, // version, profile, compat, level, length-1
            0xE1, // 1 SPS
            0x00, 0x04, 0x67, 0x64, 0x00, 0x1F, // SPS
            0x01, // 1 PPS
            0x00, 0x03, 0x68, 0xEF, 0x38, // PPS
        ]);

        let h264 = H264Data::parse(data).unwrap();
        assert!(h264.is_sequence_header());
        assert!(h264.is_keyframe());

        match h264 {
            H264Data::SequenceHeader(config) => assert_eq!(config.profile, 100),
            other => panic!("Expected SequenceHeader, got {:?}", other),
        }
    }

    #[test]
    fn test_h264_data_end_of_sequence() {
        let data = Bytes::from_static(&[
            0x02, // End of sequence
            0x00, 0x00, 0x00, // composition time
        ]);

        let h264 = H264Data::parse(data).unwrap();
        assert!(matches!(h264, H264Data::EndOfSequence));
        assert!(!h264.is_keyframe());
        assert!(!h264.is_sequence_header());
    }

    #[test]
    fn test_h264_data_nalu_keyframe() {
        // Create NALU data with IDR frame
        let data = Bytes::from_static(&[
            0x01, // AVC NALU
            0x00, 0x00, 0x00, // composition time (0)
            // NALU with length prefix
            0x00, 0x00, 0x00, 0x05, // length = 5
            0x65, 0x88, 0x84, 0x00, 0x00, // IDR NALU (type 5)
        ]);

        let h264 = H264Data::parse(data).unwrap();
        assert!(h264.is_keyframe());
        assert!(!h264.is_sequence_header());

        match h264 {
            H264Data::Frame {
                keyframe,
                composition_time,
                ..
            } => {
                assert!(keyframe);
                assert_eq!(composition_time, 0);
            }
            other => panic!("Expected Frame, got {:?}", other),
        }
    }

    #[test]
    fn test_h264_data_idr_after_parameter_sets() {
        // The IDR slice is the second NALU; the scan must walk past the SPS to find it.
        let data = Bytes::from_static(&[
            0x01, // AVC NALU
            0x00, 0x00, 0x00, // composition time
            0x00, 0x00, 0x00, 0x02, 0x67, 0x64, // SPS
            0x00, 0x00, 0x00, 0x02, 0x65, 0x88, // IDR
        ]);

        let h264 = H264Data::parse(data).unwrap();
        assert!(h264.is_keyframe());
    }

    #[test]
    fn test_h264_data_nalu_p_frame() {
        // Create NALU data with non-IDR slice
        let data = Bytes::from_static(&[
            0x01, // AVC NALU
            0x00, 0x00, 0x00, // composition time
            // NALU with length prefix
            0x00, 0x00, 0x00, 0x05, // length = 5
            0x41, 0x9A, 0x00, 0x00, 0x00, // Non-IDR slice (type 1)
        ]);

        let h264 = H264Data::parse(data).unwrap();
        assert!(!h264.is_keyframe());

        match h264 {
            H264Data::Frame { keyframe, .. } => assert!(!keyframe),
            other => panic!("Expected Frame, got {:?}", other),
        }
    }

    #[test]
    fn test_h264_composition_time_positive() {
        let data = Bytes::from_static(&[
            0x01, // AVC NALU
            0x00, 0x01, 0x00, // composition time = 256
            0x00, 0x00, 0x00, 0x01, // length
            0x41, // Non-IDR
        ]);

        match H264Data::parse(data).unwrap() {
            H264Data::Frame {
                composition_time, ..
            } => assert_eq!(composition_time, 256),
            other => panic!("Expected Frame, got {:?}", other),
        }
    }

    #[test]
    fn test_h264_composition_time_negative() {
        // Negative composition time (sign-extended from 24 bits)
        let data = Bytes::from_static(&[
            0x01, // AVC NALU
            0xFF, 0xFF, 0x00, // composition time = -256 (as signed 24-bit)
            0x00, 0x00, 0x00, 0x01, // length
            0x41, // Non-IDR
        ]);

        match H264Data::parse(data).unwrap() {
            H264Data::Frame {
                composition_time, ..
            } => assert_eq!(composition_time, -256),
            other => panic!("Expected Frame, got {:?}", other),
        }
    }

    #[test]
    fn test_h264_data_invalid_packet_type() {
        let data = Bytes::from_static(&[
            0x03, // Invalid packet type
            0x00, 0x00, 0x00,
        ]);

        let result = H264Data::parse(data);
        assert!(result.is_err());
    }

    #[test]
    fn test_h264_data_too_short() {
        let data = Bytes::from_static(&[0x00, 0x00]); // Less than 4 bytes
        let result = H264Data::parse(data);
        assert!(result.is_err());
    }

    #[test]
    fn test_avc_config_profile_names() {
        // Test various profile names
        let profiles = [
            (66, "Baseline"),
            (77, "Main"),
            (88, "Extended"),
            (100, "High"),
            (110, "High 10"),
            (122, "High 4:2:2"),
            (244, "High 4:4:4"),
            (99, "Unknown"),
        ];

        for (profile, expected_name) in profiles {
            let config = AvcConfig {
                profile,
                compatibility: 0,
                level: 31,
                nalu_length_size: 4,
                sps: vec![],
                pps: vec![],
            };
            assert_eq!(config.profile_name(), expected_name);
        }
    }

    #[test]
    fn test_avc_config_level_string() {
        let config = AvcConfig {
            profile: 100,
            compatibility: 0,
            level: 41, // Level 4.1
            nalu_length_size: 4,
            sps: vec![],
            pps: vec![],
        };
        assert_eq!(config.level_string(), "4.1");

        let config2 = AvcConfig {
            profile: 100,
            compatibility: 0,
            level: 52, // Level 5.2
            nalu_length_size: 4,
            sps: vec![],
            pps: vec![],
        };
        assert_eq!(config2.level_string(), "5.2");
    }

    #[test]
    fn test_avc_config_invalid_version() {
        let data = Bytes::from_static(&[
            0x02, // Invalid version (should be 1)
            0x64, 0x00, 0x1F, 0xFF, 0xE1, 0x00, 0x04, 0x67, 0x64, 0x00, 0x1F, 0x01, 0x00, 0x03,
            0x68, 0xEF, 0x38,
        ]);

        let result = AvcConfig::parse(data);
        assert!(result.is_err());
    }

    #[test]
    fn test_avc_config_too_short() {
        let data = Bytes::from_static(&[0x01, 0x64, 0x00]); // Less than 7 bytes
        let result = AvcConfig::parse(data);
        assert!(result.is_err());
    }

    #[test]
    fn test_nalu_iterator() {
        // Create AVCC-format data with multiple NALUs
        let data: &[u8] = &[
            0x00, 0x00, 0x00, 0x03, // length = 3
            0x67, 0x64, 0x00, // SPS NALU
            0x00, 0x00, 0x00, 0x02, // length = 2
            0x68, 0xEF, // PPS NALU
        ];

        let mut iter = NaluIterator::new(data, 4);

        let nalu1 = iter.next().unwrap();
        assert_eq!(nalu1.len(), 3);
        assert_eq!(NaluType::from_byte(nalu1[0]), Some(NaluType::Sps));

        let nalu2 = iter.next().unwrap();
        assert_eq!(nalu2.len(), 2);
        assert_eq!(NaluType::from_byte(nalu2[0]), Some(NaluType::Pps));

        assert!(iter.next().is_none());
    }

    #[test]
    fn test_nalu_iterator_different_length_sizes() {
        // Test with 2-byte length prefix
        let data: &[u8] = &[
            0x00, 0x02, // length = 2
            0x65, 0x88, // IDR NALU
        ];

        let mut iter = NaluIterator::new(data, 2);
        let nalu = iter.next().unwrap();
        assert_eq!(nalu, &[0x65u8, 0x88][..]);
        assert!(iter.next().is_none());
    }

    #[test]
    fn test_nalu_iterator_empty() {
        let data: &[u8] = &[];
        let mut iter = NaluIterator::new(data, 4);
        assert!(iter.next().is_none());
    }

    #[test]
    fn test_nalu_iterator_truncated() {
        // Length says 10 bytes but only 3 available
        let data: &[u8] = &[
            0x00, 0x00, 0x00, 0x0A, // length = 10
            0x67, 0x64, 0x00, // Only 3 bytes
        ];

        let mut iter = NaluIterator::new(data, 4);
        assert!(iter.next().is_none()); // Should return None for truncated data
    }
}