Skip to main content

moq_vaapi/codec/h264/
nalu_writer.rs

1// Copyright 2024 The ChromiumOS Authors
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4use std::fmt;
5use std::io::Write;
6
7use crate::bitstream_utils::BitWriter;
8use crate::bitstream_utils::BitWriterError;
9
/// Byte-level filter placed in front of an [`std::io::Write`] sink that, when enabled,
/// inserts H.264 emulation prevention bytes (`0x03`) into the data written through it.
struct EmulationPrevention<W>
where
    W: Write,
{
    /// Destination that receives the (possibly escaped) bytes.
    out: W,
    /// Bytes accepted but not yet forwarded to `out`; index 0 holds the most recent one,
    /// index 1 the one before it.
    prev_bytes: [Option<u8>; 2],
    /// Whether emulation prevention is applied; when `false`, data passes through unchanged.
    ep_enabled: bool,
}
18
19impl<W: Write> EmulationPrevention<W> {
20	fn new(writer: W, ep_enabled: bool) -> Self {
21		Self {
22			out: writer,
23			prev_bytes: [None; 2],
24			ep_enabled,
25		}
26	}
27
28	fn write_byte(&mut self, curr_byte: u8) -> std::io::Result<()> {
29		if self.prev_bytes[1] == Some(0x00) && self.prev_bytes[0] == Some(0x00) && curr_byte <= 0x03 {
30			self.out.write_all(&[0x00, 0x00, 0x03, curr_byte])?;
31			self.prev_bytes = [None; 2];
32		} else {
33			if let Some(byte) = self.prev_bytes[1] {
34				self.out.write_all(&[byte])?;
35			}
36
37			self.prev_bytes[1] = self.prev_bytes[0];
38			self.prev_bytes[0] = Some(curr_byte);
39		}
40
41		Ok(())
42	}
43
44	/// Writes a H.264 NALU header.
45	fn write_header(&mut self, idc: u8, type_: u8) -> NaluWriterResult<()> {
46		self.out
47			.write_all(&[0x00, 0x00, 0x00, 0x01, (idc & 0b11) << 5 | (type_ & 0b11111)])?;
48
49		Ok(())
50	}
51
52	fn has_data_pending(&self) -> bool {
53		self.prev_bytes[0].is_some() || self.prev_bytes[1].is_some()
54	}
55}
56
57impl<W: Write> Write for EmulationPrevention<W> {
58	fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
59		if !self.ep_enabled {
60			self.out.write_all(buf)?;
61			return Ok(buf.len());
62		}
63
64		for byte in buf {
65			self.write_byte(*byte)?;
66		}
67
68		Ok(buf.len())
69	}
70
71	fn flush(&mut self) -> std::io::Result<()> {
72		if let Some(byte) = self.prev_bytes[1].take() {
73			self.out.write_all(&[byte])?;
74		}
75
76		if let Some(byte) = self.prev_bytes[0].take() {
77			self.out.write_all(&[byte])?;
78		}
79
80		self.out.flush()
81	}
82}
83
84impl<W: Write> Drop for EmulationPrevention<W> {
85	fn drop(&mut self) {
86		if let Err(e) = self.flush() {
87			log::error!("Unable to flush pending bytes {e:?}");
88		}
89	}
90}
91
92#[derive(Debug)]
93pub enum NaluWriterError {
94	Overflow,
95	Io(std::io::Error),
96	BitWriterError(BitWriterError),
97}
98
99impl fmt::Display for NaluWriterError {
100	fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
101		match self {
102			NaluWriterError::Overflow => write!(f, "value increment caused value overflow"),
103			NaluWriterError::Io(x) => write!(f, "{}", x.to_string()),
104			NaluWriterError::BitWriterError(x) => write!(f, "{}", x.to_string()),
105		}
106	}
107}
108
109impl From<std::io::Error> for NaluWriterError {
110	fn from(err: std::io::Error) -> Self {
111		NaluWriterError::Io(err)
112	}
113}
114
115impl From<BitWriterError> for NaluWriterError {
116	fn from(err: BitWriterError) -> Self {
117		NaluWriterError::BitWriterError(err)
118	}
119}
120
121pub type NaluWriterResult<T> = std::result::Result<T, NaluWriterError>;
122
123/// A writer for H.264 bitstream. It is capable of outputing bitstream with
124/// emulation-prevention.
125pub struct NaluWriter<W: Write>(BitWriter<EmulationPrevention<W>>);
126
127impl<W: Write> NaluWriter<W> {
128	pub fn new(writer: W, ep_enabled: bool) -> Self {
129		Self(BitWriter::new(EmulationPrevention::new(writer, ep_enabled)))
130	}
131
132	/// Writes fixed bit size integer (up to 32 bit) output with emulation
133	/// prevention if enabled. Corresponds to `f(n)` in H.264 spec.
134	pub fn write_f<T: Into<u32>>(&mut self, bits: usize, value: T) -> NaluWriterResult<usize> {
135		self.0.write_f(bits, value).map_err(NaluWriterError::BitWriterError)
136	}
137
138	/// An alias to [`Self::write_f`] Corresponds to `n(n)` in H.264 spec.
139	pub fn write_u<T: Into<u32>>(&mut self, bits: usize, value: T) -> NaluWriterResult<usize> {
140		self.write_f(bits, value)
141	}
142
143	/// Writes a number in exponential golumb format.
144	pub fn write_exp_golumb(&mut self, value: u32) -> NaluWriterResult<()> {
145		let value = value.checked_add(1).ok_or(NaluWriterError::Overflow)?;
146		let bits = 32 - value.leading_zeros() as usize;
147		let zeros = bits - 1;
148
149		self.write_f(zeros, 0u32)?;
150		self.write_f(bits, value)?;
151
152		Ok(())
153	}
154
155	/// Writes a unsigned integer in exponential golumb format.
156	/// Coresponds to `ue(v)` in H.264 spec.
157	pub fn write_ue<T: Into<u32>>(&mut self, value: T) -> NaluWriterResult<()> {
158		let value = value.into();
159
160		self.write_exp_golumb(value)
161	}
162
163	/// Writes a signed integer in exponential golumb format.
164	/// Coresponds to `se(v)` in H.264 spec.
165	pub fn write_se<T: Into<i32>>(&mut self, value: T) -> NaluWriterResult<()> {
166		let value: i32 = value.into();
167		let abs_value: u32 = value.unsigned_abs();
168
169		if value <= 0 {
170			self.write_ue(2 * abs_value)
171		} else {
172			self.write_ue(2 * abs_value - 1)
173		}
174	}
175
176	/// Returns `true` if ['Self`] hold data that wasn't written to [`std::io::Write`]
177	pub fn has_data_pending(&self) -> bool {
178		self.0.has_data_pending() || self.0.inner().has_data_pending()
179	}
180
181	/// Writes a H.264 NALU header.
182	pub fn write_header(&mut self, idc: u8, _type: u8) -> NaluWriterResult<()> {
183		self.0.flush()?;
184		let _num_bytes = self.0.inner_mut().write_header(idc, _type)?;
185		// self.0.bits_written += num_bytes * 8;
186		Ok(())
187	}
188
189	/// Returns `true` if next bits will be aligned to 8
190	pub fn aligned(&self) -> bool {
191		!self.0.has_data_pending()
192	}
193
194	/// Returns the number of trailing bits in the last byte.
195	pub fn flush(&mut self) -> NaluWriterResult<u8> {
196		Ok(self.0.flush()?)
197	}
198}