1use std::fmt;
5use std::io::Write;
6
7use crate::bitstream_utils::BitWriter;
8use crate::bitstream_utils::BitWriterError;
9
/// Byte-level filter that sits between a bit writer and the real output `W`.
///
/// When enabled, it holds back up to two bytes so that a `0x03` byte can be
/// inserted whenever two `0x00` bytes are followed by a byte `<= 0x03`.
struct EmulationPrevention<W>
where
    W: Write,
{
    /// Destination that receives the filtered byte stream.
    out: W,
    /// Bytes accepted but not yet forwarded to `out`: slot `0` holds the most
    /// recent byte, slot `1` the one before it. `None` marks an empty slot.
    prev_bytes: [Option<u8>; 2],

    /// When `false`, writes are forwarded to `out` untouched.
    ep_enabled: bool,
}
18
19impl<W: Write> EmulationPrevention<W> {
20 fn new(writer: W, ep_enabled: bool) -> Self {
21 Self { out: writer, prev_bytes: [None; 2], ep_enabled }
22 }
23
24 fn write_byte(&mut self, curr_byte: u8) -> std::io::Result<()> {
25 if self.prev_bytes[1] == Some(0x00) && self.prev_bytes[0] == Some(0x00) && curr_byte <= 0x03
26 {
27 self.out.write_all(&[0x00, 0x00, 0x03, curr_byte])?;
28 self.prev_bytes = [None; 2];
29 } else {
30 if let Some(byte) = self.prev_bytes[1] {
31 self.out.write_all(&[byte])?;
32 }
33
34 self.prev_bytes[1] = self.prev_bytes[0];
35 self.prev_bytes[0] = Some(curr_byte);
36 }
37
38 Ok(())
39 }
40
41 fn write_header(&mut self, idc: u8, type_: u8) -> NaluWriterResult<()> {
43 self.out.write_all(&[0x00, 0x00, 0x00, 0x01, (idc & 0b11) << 5 | (type_ & 0b11111)])?;
44
45 Ok(())
46 }
47
48 fn has_data_pending(&self) -> bool {
49 self.prev_bytes[0].is_some() || self.prev_bytes[1].is_some()
50 }
51}
52
53impl<W: Write> Write for EmulationPrevention<W> {
54 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
55 if !self.ep_enabled {
56 self.out.write_all(buf)?;
57 return Ok(buf.len());
58 }
59
60 for byte in buf {
61 self.write_byte(*byte)?;
62 }
63
64 Ok(buf.len())
65 }
66
67 fn flush(&mut self) -> std::io::Result<()> {
68 if let Some(byte) = self.prev_bytes[1].take() {
69 self.out.write_all(&[byte])?;
70 }
71
72 if let Some(byte) = self.prev_bytes[0].take() {
73 self.out.write_all(&[byte])?;
74 }
75
76 self.out.flush()
77 }
78}
79
80impl<W: Write> Drop for EmulationPrevention<W> {
81 fn drop(&mut self) {
82 if let Err(e) = self.flush() {
83 log::error!("Unable to flush pending bytes {e:?}");
84 }
85 }
86}
87
88#[derive(Debug)]
89pub enum NaluWriterError {
90 Overflow,
91 Io(std::io::Error),
92 BitWriterError(BitWriterError),
93}
94
95impl fmt::Display for NaluWriterError {
96 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
97 match self {
98 NaluWriterError::Overflow => write!(f, "value increment caused value overflow"),
99 NaluWriterError::Io(x) => write!(f, "{}", x.to_string()),
100 NaluWriterError::BitWriterError(x) => write!(f, "{}", x.to_string()),
101 }
102 }
103}
104
105impl From<std::io::Error> for NaluWriterError {
106 fn from(err: std::io::Error) -> Self {
107 NaluWriterError::Io(err)
108 }
109}
110
111impl From<BitWriterError> for NaluWriterError {
112 fn from(err: BitWriterError) -> Self {
113 NaluWriterError::BitWriterError(err)
114 }
115}
116
117pub type NaluWriterResult<T> = std::result::Result<T, NaluWriterError>;
118
119pub struct NaluWriter<W: Write>(BitWriter<EmulationPrevention<W>>);
122
123impl<W: Write> NaluWriter<W> {
124 pub fn new(writer: W, ep_enabled: bool) -> Self {
125 Self(BitWriter::new(EmulationPrevention::new(writer, ep_enabled)))
126 }
127
128 pub fn write_f<T: Into<u32>>(&mut self, bits: usize, value: T) -> NaluWriterResult<usize> {
131 self.0.write_f(bits, value).map_err(NaluWriterError::BitWriterError)
132 }
133
134 pub fn write_u<T: Into<u32>>(&mut self, bits: usize, value: T) -> NaluWriterResult<usize> {
136 self.write_f(bits, value)
137 }
138
139 pub fn write_exp_golumb(&mut self, value: u32) -> NaluWriterResult<()> {
141 let value = value.checked_add(1).ok_or(NaluWriterError::Overflow)?;
142 let bits = 32 - value.leading_zeros() as usize;
143 let zeros = bits - 1;
144
145 self.write_f(zeros, 0u32)?;
146 self.write_f(bits, value)?;
147
148 Ok(())
149 }
150
151 pub fn write_ue<T: Into<u32>>(&mut self, value: T) -> NaluWriterResult<()> {
154 let value = value.into();
155
156 self.write_exp_golumb(value)
157 }
158
159 pub fn write_se<T: Into<i32>>(&mut self, value: T) -> NaluWriterResult<()> {
162 let value: i32 = value.into();
163 let abs_value: u32 = value.unsigned_abs();
164
165 if value <= 0 {
166 self.write_ue(2 * abs_value)
167 } else {
168 self.write_ue(2 * abs_value - 1)
169 }
170 }
171
172 pub fn has_data_pending(&self) -> bool {
174 self.0.has_data_pending() || self.0.inner().has_data_pending()
175 }
176
177 pub fn write_header(&mut self, idc: u8, _type: u8) -> NaluWriterResult<()> {
179 self.0.flush()?;
180 self.0.inner_mut().write_header(idc, _type)?;
181 Ok(())
182 }
183
184 pub fn aligned(&self) -> bool {
186 !self.0.has_data_pending()
187 }
188}
189
#[cfg(test)]
mod tests {
    use super::*;
    use crate::bitstream_utils::BitReader;

    /// Runs `fill` against a fresh writer and returns the buffer contents
    /// after the writer has gone out of scope.
    fn encode(ep_enabled: bool, fill: impl FnOnce(&mut NaluWriter<&mut Vec<u8>>)) -> Vec<u8> {
        let mut buf = Vec::<u8>::new();
        {
            let mut writer = NaluWriter::new(&mut buf, ep_enabled);
            fill(&mut writer);
        }
        buf
    }

    /// Eight single-bit writes end up as one byte, first bit most significant.
    #[test]
    fn simple_bits() {
        let bits = [true, false, false, false, true, true, true, true];

        let buf = encode(false, |writer| {
            for bit in bits {
                writer.write_f(1, bit).unwrap();
            }
        });

        assert_eq!(buf, vec![0b10001111u8]);
    }

    /// Checks the encoded byte for each of the unsigned values 0 through 9.
    #[test]
    fn simple_first_few_ue() {
        const CASES: [(u32, u8); 10] = [
            (0, 0b10000000),
            (1, 0b01000000),
            (2, 0b01100000),
            (3, 0b00100000),
            (4, 0b00101000),
            (5, 0b00110000),
            (6, 0b00111000),
            (7, 0b00010000),
            (8, 0b00010010),
            (9, 0b00010100),
        ];

        for &(value, expected) in CASES.iter() {
            let buf = encode(false, |writer| writer.write_ue(value).unwrap());
            assert_eq!(buf, vec![expected]);
        }
    }

    /// Values written with the writer are read back unchanged by `BitReader`.
    #[test]
    fn writer_reader() {
        let first = encode(false, |writer| {
            writer.write_ue(10u32).unwrap();
            writer.write_se(-42).unwrap();
            writer.write_se(3).unwrap();
            writer.write_ue(5u32).unwrap();
        });

        let mut reader = BitReader::new(&first, true);
        assert_eq!(reader.read_ue::<u32>().unwrap(), 10);
        assert_eq!(reader.read_se::<i32>().unwrap(), -42);
        assert_eq!(reader.read_se::<i32>().unwrap(), 3);
        assert_eq!(reader.read_ue::<u32>().unwrap(), 5);

        let second = encode(false, |writer| {
            writer.write_se(30).unwrap();
            writer.write_ue(100u32).unwrap();
            writer.write_se(-402).unwrap();
            writer.write_ue(50u32).unwrap();
        });

        let mut reader = BitReader::new(&second, true);
        assert_eq!(reader.read_se::<i32>().unwrap(), 30);
        assert_eq!(reader.read_ue::<u32>().unwrap(), 100);
        assert_eq!(reader.read_se::<i32>().unwrap(), -402);
        assert_eq!(reader.read_ue::<u32>().unwrap(), 50);
    }

    /// With emulation prevention on, `00 00 0x` inputs gain a `03` byte and
    /// still read back as the original bytes.
    #[test]
    fn writer_emulation_prevention() {
        fn check(input: &[u8], bitstream: &[u8]) {
            let buf = encode(true, |writer| {
                for byte in input {
                    writer.write_f(8, *byte).unwrap();
                }
            });
            assert_eq!(buf, bitstream);

            let mut reader = BitReader::new(&buf, true);
            for byte in input {
                assert_eq!(*byte, reader.read_bits::<u8>(8).unwrap());
            }
        }

        // Two zero bytes followed directly by a byte in 0x00..=0x03.
        for last in 0x00..=0x03u8 {
            check(&[0x00, 0x00, last], &[0x00, 0x00, 0x03, last]);
        }

        // Three zero bytes followed by a byte in 0x00..=0x03.
        for last in 0x00..=0x03u8 {
            check(&[0x00, 0x00, 0x00, last], &[0x00, 0x00, 0x03, 0x00, last]);
        }
    }
}