cros_codecs/codec/h264/nalu_writer.rs

// Copyright 2024 The ChromiumOS Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
4use std::fmt;
5use std::io::Write;
6
7use crate::bitstream_utils::BitWriter;
8use crate::bitstream_utils::BitWriterError;
9
10/// Internal wrapper over [`std::io::Write`] for possible emulation prevention
11struct EmulationPrevention<W: Write> {
12    out: W,
13    prev_bytes: [Option<u8>; 2],
14
15    /// Emulation prevention enabled.
16    ep_enabled: bool,
17}
18
19impl<W: Write> EmulationPrevention<W> {
20    fn new(writer: W, ep_enabled: bool) -> Self {
21        Self { out: writer, prev_bytes: [None; 2], ep_enabled }
22    }
23
24    fn write_byte(&mut self, curr_byte: u8) -> std::io::Result<()> {
25        if self.prev_bytes[1] == Some(0x00) && self.prev_bytes[0] == Some(0x00) && curr_byte <= 0x03
26        {
27            self.out.write_all(&[0x00, 0x00, 0x03, curr_byte])?;
28            self.prev_bytes = [None; 2];
29        } else {
30            if let Some(byte) = self.prev_bytes[1] {
31                self.out.write_all(&[byte])?;
32            }
33
34            self.prev_bytes[1] = self.prev_bytes[0];
35            self.prev_bytes[0] = Some(curr_byte);
36        }
37
38        Ok(())
39    }
40
41    /// Writes a H.264 NALU header.
42    fn write_header(&mut self, idc: u8, type_: u8) -> NaluWriterResult<()> {
43        self.out.write_all(&[0x00, 0x00, 0x00, 0x01, (idc & 0b11) << 5 | (type_ & 0b11111)])?;
44
45        Ok(())
46    }
47
48    fn has_data_pending(&self) -> bool {
49        self.prev_bytes[0].is_some() || self.prev_bytes[1].is_some()
50    }
51}
52
53impl<W: Write> Write for EmulationPrevention<W> {
54    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
55        if !self.ep_enabled {
56            self.out.write_all(buf)?;
57            return Ok(buf.len());
58        }
59
60        for byte in buf {
61            self.write_byte(*byte)?;
62        }
63
64        Ok(buf.len())
65    }
66
67    fn flush(&mut self) -> std::io::Result<()> {
68        if let Some(byte) = self.prev_bytes[1].take() {
69            self.out.write_all(&[byte])?;
70        }
71
72        if let Some(byte) = self.prev_bytes[0].take() {
73            self.out.write_all(&[byte])?;
74        }
75
76        self.out.flush()
77    }
78}
79
80impl<W: Write> Drop for EmulationPrevention<W> {
81    fn drop(&mut self) {
82        if let Err(e) = self.flush() {
83            log::error!("Unable to flush pending bytes {e:?}");
84        }
85    }
86}
87
88#[derive(Debug)]
89pub enum NaluWriterError {
90    Overflow,
91    Io(std::io::Error),
92    BitWriterError(BitWriterError),
93}
94
95impl fmt::Display for NaluWriterError {
96    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
97        match self {
98            NaluWriterError::Overflow => write!(f, "value increment caused value overflow"),
99            NaluWriterError::Io(x) => write!(f, "{}", x.to_string()),
100            NaluWriterError::BitWriterError(x) => write!(f, "{}", x.to_string()),
101        }
102    }
103}
104
105impl From<std::io::Error> for NaluWriterError {
106    fn from(err: std::io::Error) -> Self {
107        NaluWriterError::Io(err)
108    }
109}
110
111impl From<BitWriterError> for NaluWriterError {
112    fn from(err: BitWriterError) -> Self {
113        NaluWriterError::BitWriterError(err)
114    }
115}
116
117pub type NaluWriterResult<T> = std::result::Result<T, NaluWriterError>;
118
119/// A writer for H.264 bitstream. It is capable of outputing bitstream with
120/// emulation-prevention.
121pub struct NaluWriter<W: Write>(BitWriter<EmulationPrevention<W>>);
122
123impl<W: Write> NaluWriter<W> {
124    pub fn new(writer: W, ep_enabled: bool) -> Self {
125        Self(BitWriter::new(EmulationPrevention::new(writer, ep_enabled)))
126    }
127
128    /// Writes fixed bit size integer (up to 32 bit) output with emulation
129    /// prevention if enabled. Corresponds to `f(n)` in H.264 spec.
130    pub fn write_f<T: Into<u32>>(&mut self, bits: usize, value: T) -> NaluWriterResult<usize> {
131        self.0.write_f(bits, value).map_err(NaluWriterError::BitWriterError)
132    }
133
134    /// An alias to [`Self::write_f`] Corresponds to `n(n)` in H.264 spec.
135    pub fn write_u<T: Into<u32>>(&mut self, bits: usize, value: T) -> NaluWriterResult<usize> {
136        self.write_f(bits, value)
137    }
138
139    /// Writes a number in exponential golumb format.
140    pub fn write_exp_golumb(&mut self, value: u32) -> NaluWriterResult<()> {
141        let value = value.checked_add(1).ok_or(NaluWriterError::Overflow)?;
142        let bits = 32 - value.leading_zeros() as usize;
143        let zeros = bits - 1;
144
145        self.write_f(zeros, 0u32)?;
146        self.write_f(bits, value)?;
147
148        Ok(())
149    }
150
151    /// Writes a unsigned integer in exponential golumb format.
152    /// Coresponds to `ue(v)` in H.264 spec.
153    pub fn write_ue<T: Into<u32>>(&mut self, value: T) -> NaluWriterResult<()> {
154        let value = value.into();
155
156        self.write_exp_golumb(value)
157    }
158
159    /// Writes a signed integer in exponential golumb format.
160    /// Coresponds to `se(v)` in H.264 spec.
161    pub fn write_se<T: Into<i32>>(&mut self, value: T) -> NaluWriterResult<()> {
162        let value: i32 = value.into();
163        let abs_value: u32 = value.unsigned_abs();
164
165        if value <= 0 {
166            self.write_ue(2 * abs_value)
167        } else {
168            self.write_ue(2 * abs_value - 1)
169        }
170    }
171
172    /// Returns `true` if ['Self`] hold data that wasn't written to [`std::io::Write`]
173    pub fn has_data_pending(&self) -> bool {
174        self.0.has_data_pending() || self.0.inner().has_data_pending()
175    }
176
177    /// Writes a H.264 NALU header.
178    pub fn write_header(&mut self, idc: u8, _type: u8) -> NaluWriterResult<()> {
179        self.0.flush()?;
180        self.0.inner_mut().write_header(idc, _type)?;
181        Ok(())
182    }
183
184    /// Returns `true` if next bits will be aligned to 8
185    pub fn aligned(&self) -> bool {
186        !self.0.has_data_pending()
187    }
188}
189
#[cfg(test)]
mod tests {
    use super::*;
    use crate::bitstream_utils::BitReader;

    /// Runs `write` against a fresh [`NaluWriter`] and returns the produced bytes
    /// after the writer has been dropped (which flushes any pending data).
    fn collect_output<F>(ep_enabled: bool, write: F) -> Vec<u8>
    where
        F: FnOnce(&mut NaluWriter<&mut Vec<u8>>),
    {
        let mut buf = Vec::<u8>::new();
        let mut writer = NaluWriter::new(&mut buf, ep_enabled);
        write(&mut writer);
        drop(writer);
        buf
    }

    #[test]
    fn simple_bits() {
        let buf = collect_output(false, |writer| {
            for bit in [true, false, false, false, true, true, true, true] {
                writer.write_f(1, bit).unwrap();
            }
        });
        assert_eq!(buf, vec![0b10001111u8]);
    }

    #[test]
    fn simple_first_few_ue() {
        // (value, expected single output byte)
        let table: [(u32, u8); 10] = [
            (0, 0b10000000),
            (1, 0b01000000),
            (2, 0b01100000),
            (3, 0b00100000),
            (4, 0b00101000),
            (5, 0b00110000),
            (6, 0b00111000),
            (7, 0b00010000),
            (8, 0b00010010),
            (9, 0b00010100),
        ];

        for (value, byte) in table {
            let buf = collect_output(false, |writer| writer.write_ue(value).unwrap());
            assert_eq!(buf, vec![byte]);
        }
    }

    #[test]
    fn writer_reader() {
        let first = collect_output(false, |writer| {
            writer.write_ue(10u32).unwrap();
            writer.write_se(-42).unwrap();
            writer.write_se(3).unwrap();
            writer.write_ue(5u32).unwrap();
        });

        let mut reader = BitReader::new(&first, true);
        assert_eq!(reader.read_ue::<u32>().unwrap(), 10);
        assert_eq!(reader.read_se::<i32>().unwrap(), -42);
        assert_eq!(reader.read_se::<i32>().unwrap(), 3);
        assert_eq!(reader.read_ue::<u32>().unwrap(), 5);

        let second = collect_output(false, |writer| {
            writer.write_se(30).unwrap();
            writer.write_ue(100u32).unwrap();
            writer.write_se(-402).unwrap();
            writer.write_ue(50u32).unwrap();
        });

        let mut reader = BitReader::new(&second, true);
        assert_eq!(reader.read_se::<i32>().unwrap(), 30);
        assert_eq!(reader.read_ue::<u32>().unwrap(), 100);
        assert_eq!(reader.read_se::<i32>().unwrap(), -402);
        assert_eq!(reader.read_ue::<u32>().unwrap(), 50);
    }

    #[test]
    fn writer_emulation_prevention() {
        /// Writes `input` byte by byte with emulation prevention enabled, checks the
        /// escaped output, then checks that reading it back yields `input` again.
        fn check(input: &[u8], expected: &[u8]) {
            let buf = collect_output(true, |writer| {
                for &byte in input {
                    writer.write_f(8, byte).unwrap();
                }
            });
            assert_eq!(buf, expected);

            let mut reader = BitReader::new(&buf, true);
            for &byte in input {
                assert_eq!(byte, reader.read_bits::<u8>(8).unwrap());
            }
        }

        check(&[0x00, 0x00, 0x00], &[0x00, 0x00, 0x03, 0x00]);
        check(&[0x00, 0x00, 0x01], &[0x00, 0x00, 0x03, 0x01]);
        check(&[0x00, 0x00, 0x02], &[0x00, 0x00, 0x03, 0x02]);
        check(&[0x00, 0x00, 0x03], &[0x00, 0x00, 0x03, 0x03]);

        check(&[0x00, 0x00, 0x00, 0x00], &[0x00, 0x00, 0x03, 0x00, 0x00]);
        check(&[0x00, 0x00, 0x00, 0x01], &[0x00, 0x00, 0x03, 0x00, 0x01]);
        check(&[0x00, 0x00, 0x00, 0x02], &[0x00, 0x00, 0x03, 0x00, 0x02]);
        check(&[0x00, 0x00, 0x00, 0x03], &[0x00, 0x00, 0x03, 0x00, 0x03]);
    }
}