h264_parser/
nal.rs

1use crate::{Error, Result};
2
/// The `nal_unit_type` code of an H.264 NAL unit (the low five bits of the
/// NAL header byte).
///
/// Codes 0..=16 each have a dedicated variant. Codes 17..=23 and 24..=31 are
/// grouped into [`NalUnitType::Reserved`] and [`NalUnitType::UnspecifiedExt`],
/// which keep the raw code as their payload.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NalUnitType {
    /// Code 0: unspecified.
    Unspecified,
    /// Code 1: coded slice of a non-IDR picture.
    NonIdrSlice,
    /// Code 2: coded slice data partition A.
    DataPartitionA,
    /// Code 3: coded slice data partition B.
    DataPartitionB,
    /// Code 4: coded slice data partition C.
    DataPartitionC,
    /// Code 5: coded slice of an IDR picture.
    IdrSlice,
    /// Code 6: supplemental enhancement information.
    Sei,
    /// Code 7: sequence parameter set.
    Sps,
    /// Code 8: picture parameter set.
    Pps,
    /// Code 9: access unit delimiter.
    Aud,
    /// Code 10: end of sequence.
    EndOfSeq,
    /// Code 11: end of stream.
    EndOfStream,
    /// Code 12: filler data.
    Filler,
    /// Code 13: sequence parameter set extension.
    SpsExt,
    /// Code 14: prefix NAL unit.
    Prefix,
    /// Code 15: subset sequence parameter set.
    SubsetSps,
    /// Code 16: depth parameter set.
    DepthParameterSet,
    /// Codes 17..=23 as classified by `From<u8>`; the payload is the raw code.
    Reserved(u8),
    /// Codes 24..=31 as classified by `From<u8>`; the payload is the raw code.
    UnspecifiedExt(u8),
}

impl NalUnitType {
    /// Returns the numeric `nal_unit_type` code of this variant.
    ///
    /// The payload-carrying variants return their stored byte unchanged, so
    /// the result is whatever value the variant was constructed with.
    pub fn as_u8(&self) -> u8 {
        match *self {
            Self::Reserved(raw) | Self::UnspecifiedExt(raw) => raw,
            Self::Unspecified => 0,
            Self::NonIdrSlice => 1,
            Self::DataPartitionA => 2,
            Self::DataPartitionB => 3,
            Self::DataPartitionC => 4,
            Self::IdrSlice => 5,
            Self::Sei => 6,
            Self::Sps => 7,
            Self::Pps => 8,
            Self::Aud => 9,
            Self::EndOfSeq => 10,
            Self::EndOfStream => 11,
            Self::Filler => 12,
            Self::SpsExt => 13,
            Self::Prefix => 14,
            Self::SubsetSps => 15,
            Self::DepthParameterSet => 16,
        }
    }
}

impl From<u8> for NalUnitType {
    /// Classifies a raw `nal_unit_type` code.
    ///
    /// Codes 0..=16 map to their dedicated variant, 17..=23 to `Reserved`,
    /// and 24..=31 to `UnspecifiedExt`. Any value above 31 (which does not fit
    /// in five bits) falls back to `Unspecified`, so the original value is not
    /// recoverable through [`NalUnitType::as_u8`] in that case.
    fn from(value: u8) -> Self {
        // Variants with a fixed code, laid out so that index == code.
        const FIXED: [NalUnitType; 17] = [
            NalUnitType::Unspecified,       // 0
            NalUnitType::NonIdrSlice,       // 1
            NalUnitType::DataPartitionA,    // 2
            NalUnitType::DataPartitionB,    // 3
            NalUnitType::DataPartitionC,    // 4
            NalUnitType::IdrSlice,          // 5
            NalUnitType::Sei,               // 6
            NalUnitType::Sps,               // 7
            NalUnitType::Pps,               // 8
            NalUnitType::Aud,               // 9
            NalUnitType::EndOfSeq,          // 10
            NalUnitType::EndOfStream,       // 11
            NalUnitType::Filler,            // 12
            NalUnitType::SpsExt,            // 13
            NalUnitType::Prefix,            // 14
            NalUnitType::SubsetSps,         // 15
            NalUnitType::DepthParameterSet, // 16
        ];

        if let Some(&fixed) = FIXED.get(usize::from(value)) {
            return fixed;
        }

        match value {
            17..=23 => Self::Reserved(value),
            24..=31 => Self::UnspecifiedExt(value),
            _ => Self::Unspecified,
        }
    }
}
78
79#[derive(Debug, Clone, PartialEq, Eq)]
80pub struct Nal {
81    pub start_code_len: u8,
82    pub ref_idc: u8,
83    pub nal_type: NalUnitType,
84    pub ebsp: Vec<u8>,
85}
86
87impl Nal {
88    pub fn parse(start_code_len: u8, data: &[u8]) -> Result<Self> {
89        if data.is_empty() {
90            return Err(Error::InvalidNalHeader);
91        }
92
93        let header = data[0];
94
95        let forbidden_zero_bit = (header >> 7) & 1;
96        if forbidden_zero_bit != 0 {
97            return Err(Error::InvalidNalHeader);
98        }
99
100        let ref_idc = (header >> 5) & 0b11;
101        let nal_unit_type = header & 0b11111;
102        let nal_type = NalUnitType::from(nal_unit_type);
103
104        let ebsp = if data.len() > 1 {
105            data[1..].to_vec()
106        } else {
107            Vec::new()
108        };
109
110        Ok(Nal {
111            start_code_len,
112            ref_idc,
113            nal_type,
114            ebsp,
115        })
116    }
117
118    pub fn to_rbsp(&self) -> Vec<u8> {
119        ebsp_to_rbsp(&self.ebsp)
120    }
121
122    pub fn is_slice(&self) -> bool {
123        matches!(
124            self.nal_type,
125            NalUnitType::NonIdrSlice
126                | NalUnitType::IdrSlice
127                | NalUnitType::DataPartitionA
128                | NalUnitType::DataPartitionB
129                | NalUnitType::DataPartitionC
130        )
131    }
132
133    pub fn is_vcl(&self) -> bool {
134        match self.nal_type {
135            NalUnitType::NonIdrSlice
136            | NalUnitType::DataPartitionA
137            | NalUnitType::DataPartitionB
138            | NalUnitType::DataPartitionC
139            | NalUnitType::IdrSlice => true,
140            _ => false,
141        }
142    }
143}
144
/// Strips emulation prevention bytes from an encapsulated payload.
///
/// Every `0x03` byte that immediately follows two `0x00` bytes in `ebsp` is
/// dropped (including one that is the final byte of the input); all other
/// bytes are copied through in order.
pub fn ebsp_to_rbsp(ebsp: &[u8]) -> Vec<u8> {
    // The output can only be as long as the input or shorter.
    let mut unescaped = Vec::with_capacity(ebsp.len());
    let mut remaining = ebsp;

    loop {
        match remaining {
            // Escape sequence: keep the two zeros, discard the 0x03.
            [0x00, 0x00, 0x03, tail @ ..] => {
                unescaped.extend_from_slice(&[0x00, 0x00]);
                remaining = tail;
            }
            [byte, tail @ ..] => {
                unescaped.push(*byte);
                remaining = tail;
            }
            [] => break,
        }
    }

    unescaped
}
162
/// Inserts emulation prevention bytes into a raw payload.
///
/// Whenever two consecutive `0x00` bytes have just been written and the next
/// input byte is `0x00..=0x03`, a `0x03` byte is written first. All input
/// bytes are otherwise copied through in order.
pub fn rbsp_to_ebsp(rbsp: &[u8]) -> Vec<u8> {
    // Reserve room for roughly one inserted byte per three input bytes.
    let mut escaped = Vec::with_capacity(rbsp.len() + rbsp.len() / 3);
    // Number of 0x00 bytes written back-to-back since the last non-zero byte
    // or inserted escape; never exceeds 2.
    let mut zero_run = 0;

    for byte in rbsp.iter().copied() {
        let escape = zero_run == 2 && matches!(byte, 0x00..=0x03);
        if escape {
            escaped.push(0x03);
        }
        escaped.push(byte);

        zero_run = match (byte, escape) {
            // The inserted 0x03 broke the run, so this zero starts a new one.
            (0x00, true) => 1,
            (0x00, false) => zero_run + 1,
            _ => 0,
        };
    }

    escaped
}
184
#[cfg(test)]
mod tests {
    use super::*;

    /// Parsing splits the header byte into its fields and keeps the rest as payload.
    #[test]
    fn test_nal_parse() {
        // 0x67 = 0b0_11_00111: forbidden bit clear, ref_idc 3, type 7.
        let nal = Nal::parse(4, &[0x67, 0x42, 0x00, 0x1f]).unwrap();

        assert_eq!(nal.ref_idc, 3);
        assert_eq!(nal.nal_type, NalUnitType::Sps);
        assert_eq!(nal.ebsp, [0x42, 0x00, 0x1f]);
    }

    /// Each `00 00 03` sequence loses its trailing `03`.
    #[test]
    fn test_ebsp_to_rbsp() {
        let escaped = [
            0x00, 0x00, 0x03, 0x00, 0x00, 0x03, 0x01, 0x00, 0x00, 0x03, 0x02,
        ];

        assert_eq!(
            ebsp_to_rbsp(&escaped),
            [0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x02]
        );
    }

    /// A `03` is inserted after every pair of zeros followed by a byte `<= 0x03`.
    #[test]
    fn test_rbsp_to_ebsp() {
        let raw = [0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x02];

        assert_eq!(
            rbsp_to_ebsp(&raw),
            [0x00, 0x00, 0x03, 0x00, 0x00, 0x03, 0x01, 0x00, 0x00, 0x03, 0x02]
        );
    }

    /// Raw codes map to the expected variants, including a reserved code.
    #[test]
    fn test_nal_type_conversion() {
        assert_eq!(NalUnitType::from(5), NalUnitType::IdrSlice);
        assert_eq!(NalUnitType::from(7), NalUnitType::Sps);
        assert_eq!(NalUnitType::from(8), NalUnitType::Pps);
        assert_eq!(NalUnitType::from(20), NalUnitType::Reserved(20));
    }
}