1use std::borrow::Cow;
6use std::fmt;
7use std::io::Cursor;
8use std::io::Read;
9use std::io::Seek;
10use std::io::SeekFrom;
11use std::io::Write;
12use std::marker::PhantomData;
13
14use crate::codec::h264::parser::Nalu as H264Nalu;
15
/// A most-significant-bit-first reader over a borrowed byte slice.
///
/// When constructed with `needs_epb == true`, a `0x03` byte that follows two
/// zero bytes is treated as an emulation prevention byte: it is skipped and
/// counted instead of being returned as data.
#[derive(Clone)]
pub(crate) struct BitReader<'a> {
    /// Cursor over the underlying bytes; its position is the next byte to load.
    data: Cursor<&'a [u8]>,
    /// The byte bits are currently being taken from.
    curr_byte: u8,
    /// Number of low-order bits of `curr_byte` that have not been consumed yet.
    num_remaining_bits_in_curr_byte: usize,
    /// The two most recently loaded bytes, used to spot emulation prevention bytes.
    prev_two_bytes: u16,
    /// Number of emulation prevention bytes skipped so far.
    num_epb: usize,
    /// Whether emulation prevention bytes must be detected and skipped.
    needs_epb: bool,
    /// Total number of bits returned by successful reads.
    position: u64,
}

/// Error returned when the next byte cannot be loaded from the input.
#[derive(Debug)]
pub(crate) enum GetByteError {
    /// The underlying data has been exhausted.
    OutOfBits,
}

impl fmt::Display for GetByteError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "reader ran out of bits")
    }
}

/// Errors that can occur while reading a group of bits.
#[derive(Debug)]
pub(crate) enum ReadBitsError {
    /// More than 31 bits were requested in a single read; carries the requested count.
    TooManyBitsRequested(usize),
    /// Loading the next byte failed.
    GetByte(GetByteError),
    /// The value that was read does not fit in the requested output type.
    ConversionFailed,
}

impl fmt::Display for ReadBitsError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            ReadBitsError::TooManyBitsRequested(bits) => {
                write!(f, "more than 31 ({}) bits were requested", bits)
            }
            ReadBitsError::GetByte(_) => write!(f, "failed to advance the current byte"),
            ReadBitsError::ConversionFailed => {
                write!(f, "failed to convert read input to target type")
            }
        }
    }
}

impl From<GetByteError> for ReadBitsError {
    fn from(err: GetByteError) -> Self {
        ReadBitsError::GetByte(err)
    }
}

impl<'a> BitReader<'a> {
    /// Creates a reader over `data`.
    ///
    /// `needs_epb` enables detection and skipping of emulation prevention bytes.
    pub fn new(data: &'a [u8], needs_epb: bool) -> Self {
        Self {
            data: Cursor::new(data),
            curr_byte: 0,
            num_remaining_bits_in_curr_byte: 0,
            // Anything non-zero: no emulation prevention byte can be detected
            // before two real zero bytes have been seen.
            prev_two_bytes: 0xffff,
            num_epb: 0,
            needs_epb,
            position: 0,
        }
    }

    /// Reads a single bit and returns it as a `bool`.
    ///
    /// # Errors
    /// Fails if the input is exhausted.
    pub fn read_bit(&mut self) -> Result<bool, String> {
        // A 1-bit read is masked to 0 or 1, so no other value is possible.
        Ok(self.read_bits::<u32>(1)? != 0)
    }

    /// Reads `num_bits` bits (at most 31), most significant bit first, and
    /// converts the result to `U`.
    ///
    /// # Errors
    /// Fails if more than 31 bits are requested, if the input runs out, or if
    /// the value does not fit in `U`.
    pub fn read_bits<U: TryFrom<u32>>(&mut self, num_bits: usize) -> Result<U, String> {
        if num_bits > 31 {
            return Err(ReadBitsError::TooManyBitsRequested(num_bits).to_string());
        }

        let mut bits_left = num_bits;
        let mut out = 0u32;

        // Drain the current byte and load new ones until the remainder of the
        // request fits in the current byte. Already-consumed high bits of
        // `curr_byte` land above `num_bits` and are removed by the final mask.
        while self.num_remaining_bits_in_curr_byte < bits_left {
            out |= u32::from(self.curr_byte) << (bits_left - self.num_remaining_bits_in_curr_byte);
            bits_left -= self.num_remaining_bits_in_curr_byte;
            self.move_to_next_byte().map_err(|err| err.to_string())?;
        }

        // Widen before shifting: the shift amount is 8 when zero bits are taken
        // from a freshly loaded byte, which would overflow a `u8` shift.
        out |= u32::from(self.curr_byte) >> (self.num_remaining_bits_in_curr_byte - bits_left);
        out &= (1 << num_bits) - 1;
        self.num_remaining_bits_in_curr_byte -= bits_left;
        self.position += num_bits as u64;

        U::try_from(out).map_err(|_| ReadBitsError::ConversionFailed.to_string())
    }

    /// Reads a two's complement signed integer that is `num_bits` wide.
    ///
    /// A width of zero yields `0`.
    ///
    /// # Errors
    /// Same failure conditions as [`BitReader::read_bits`], applied to `i32`.
    pub fn read_bits_signed<U: TryFrom<i32>>(&mut self, num_bits: usize) -> Result<U, String> {
        let raw = self.read_bits::<u32>(num_bits)?;
        let mut out =
            i32::try_from(raw).map_err(|_| ReadBitsError::ConversionFailed.to_string())?;

        // Sign-extend when the top bit of the field is set. `-1 << n` sets
        // every bit from position `n` upwards without overflowing for n == 31.
        if num_bits > 0 && out >> (num_bits - 1) != 0 {
            out |= -1i32 << num_bits;
        }

        U::try_from(out).map_err(|_| ReadBitsError::ConversionFailed.to_string())
    }

    /// Reads `num_bits` bits, but only if the reader is at a byte boundary.
    ///
    /// # Errors
    /// Fails if the reader is not byte-aligned, or for any reason
    /// [`BitReader::read_bits`] fails.
    pub fn read_bits_aligned<U: TryFrom<u32>>(&mut self, num_bits: usize) -> Result<U, String> {
        if self.num_remaining_bits_in_curr_byte % 8 != 0 {
            return Err("Attempted unaligned read_le()".into());
        }

        self.read_bits(num_bits)
    }

    /// Discards `num_bits` bits from the stream.
    ///
    /// # Errors
    /// Fails if the input runs out before all bits have been skipped.
    pub fn skip_bits(&mut self, mut num_bits: usize) -> Result<(), String> {
        // `read_bits` handles at most 31 bits per call, so skip in chunks.
        while num_bits > 0 {
            let n = std::cmp::min(num_bits, 31);
            self.read_bits::<u32>(n)?;
            num_bits -= n;
        }

        Ok(())
    }

    /// Returns the number of unread bits: the bytes not yet loaded plus the
    /// unconsumed bits of the current byte.
    pub fn num_bits_left(&mut self) -> usize {
        let cur_pos = self.data.position();
        // Seeking to the end of an in-memory cursor is not expected to fail;
        // fall back to "no further bytes" rather than panicking if it does.
        let end_pos = self.data.seek(SeekFrom::End(0)).unwrap_or(cur_pos);
        self.data.set_position(cur_pos);

        (end_pos.saturating_sub(cur_pos) as usize) * 8 + self.num_remaining_bits_in_curr_byte
    }

    /// Returns the number of emulation prevention bytes skipped so far.
    pub fn num_epb(&self) -> usize {
        self.num_epb
    }

    /// Returns `true` if any bit after the next unread bit is set, either in
    /// the current byte or in any byte that has not been loaded yet.
    ///
    /// If the current byte is fully consumed, the next byte is loaded first;
    /// `false` is returned when no such byte exists. The bytes that follow are
    /// inspected directly, without emulation-prevention handling, and the
    /// read position is left untouched.
    // NOTE(review): presumably the `more_rbsp_data()` check of the H.264/H.265
    // syntax — confirm against callers.
    pub fn has_more_rsbp_data(&mut self) -> bool {
        if self.num_remaining_bits_in_curr_byte == 0 && self.move_to_next_byte().is_err() {
            // No more data at all.
            return false;
        }

        // At least one bit remains here, so the shift amount is in 0..=7.
        let after_next_bit = (1u8 << (self.num_remaining_bits_in_curr_byte - 1)) - 1;
        if self.curr_byte & after_next_bit != 0 {
            return true;
        }

        // Look for any non-zero byte in the rest of the buffer without moving
        // the cursor, so later reads continue from the same place.
        let next_byte = self.data.position() as usize;
        self.data
            .get_ref()
            .get(next_byte..)
            .map_or(false, |rest| rest.iter().any(|&byte| byte != 0))
    }

    /// Reads an unsigned Exp-Golomb coded number.
    ///
    /// Bits consumed before a failure is detected are not restored.
    ///
    /// # Errors
    /// Fails if the input runs out, if more than 31 leading zero bits are
    /// found, or if the value does not fit in `U`.
    pub fn read_ue<U: TryFrom<u32>>(&mut self) -> Result<U, String> {
        let mut num_bits = 0;

        // Count the leading zero bits up to the terminating one bit.
        while self.read_bits::<u32>(1)? == 0 {
            num_bits += 1;
            if num_bits > 31 {
                return Err("invalid stream".into());
            }
        }

        // value = 2^num_bits - 1 + suffix, where the suffix is `num_bits` wide.
        let value = ((1u32 << num_bits) - 1)
            .checked_add(self.read_bits::<u32>(num_bits)?)
            .ok_or_else(|| String::from("read number cannot fit in 32 bits"))?;

        U::try_from(value).map_err(|_| "conversion error".into())
    }

    /// Reads an unsigned Exp-Golomb number and checks that it lies in `min..=max`.
    ///
    /// # Errors
    /// Fails if the read fails, the value is out of bounds, or it does not fit in `U`.
    pub fn read_ue_bounded<U: TryFrom<u32>>(&mut self, min: u32, max: u32) -> Result<U, String> {
        let ue = self.read_ue::<u32>()?;
        if ue > max || ue < min {
            Err(format!("Value out of bounds: expected {} - {}, got {}", min, max, ue))
        } else {
            U::try_from(ue).map_err(|_| String::from("Conversion error"))
        }
    }

    /// Reads an unsigned Exp-Golomb number and checks that it is at most `max`.
    ///
    /// # Errors
    /// Same as [`BitReader::read_ue_bounded`] with a lower bound of zero.
    pub fn read_ue_max<U: TryFrom<u32>>(&mut self, max: u32) -> Result<U, String> {
        self.read_ue_bounded(0, max)
    }

    /// Reads a signed Exp-Golomb coded number.
    ///
    /// The unsigned code number `k` maps to `-(k / 2)` when `k` is even and to
    /// `k / 2 + 1` when `k` is odd.
    ///
    /// # Errors
    /// Fails if the underlying unsigned read fails or the value does not fit in `U`.
    pub fn read_se<U: TryFrom<i32>>(&mut self) -> Result<U, String> {
        let ue = self.read_ue::<u32>()?;

        // Halve while still unsigned: `ue / 2` is at most `i32::MAX`, whereas
        // casting `ue` itself to `i32` flips the sign of large code numbers.
        let half = (ue / 2) as i32;
        // For odd `ue`, `half` is at most `i32::MAX - 1`, so `+ 1` cannot overflow.
        let se = if ue % 2 == 0 { -half } else { half + 1 };

        U::try_from(se).map_err(|_| String::from("Conversion error"))
    }

    /// Reads a signed Exp-Golomb number and checks that it lies in `min..=max`.
    ///
    /// # Errors
    /// Fails if the read fails, the value is out of bounds, or it does not fit in `U`.
    pub fn read_se_bounded<U: TryFrom<i32>>(&mut self, min: i32, max: i32) -> Result<U, String> {
        let se = self.read_se::<i32>()?;
        if se < min || se > max {
            Err(format!(
                "Value out of bounds, expected between {}-{}, got {}",
                min, max, se
            ))
        } else {
            U::try_from(se).map_err(|_| String::from("Conversion error"))
        }
    }

    /// Reads a little-endian integer made of `num_bits` whole bytes.
    ///
    /// Despite its name, `num_bits` is a count of bytes: one aligned 8-bit read
    /// is performed per unit.
    ///
    /// # Errors
    /// Fails if more than 4 bytes are requested (the result is assembled in a
    /// `u32`), if the reader is not byte-aligned, if the input runs out, or if
    /// the value does not fit in `U`.
    pub fn read_le<U: TryFrom<u32>>(&mut self, num_bits: u8) -> Result<U, String> {
        if num_bits > 4 {
            return Err(format!(
                "read_le() supports at most 4 bytes, {} requested",
                num_bits
            ));
        }

        let mut t = 0u32;
        for i in 0..num_bits {
            let byte = self.read_bits_aligned::<u32>(8)?;
            t |= byte << (u32::from(i) * 8);
        }

        U::try_from(t).map_err(|_| String::from("Conversion error"))
    }

    /// Returns the total number of bits returned by successful reads.
    pub fn position(&self) -> u64 {
        self.position
    }

    /// Fetches the next raw byte from the underlying data.
    fn get_byte(&mut self) -> Result<u8, GetByteError> {
        let mut buf = [0u8; 1];
        self.data.read_exact(&mut buf).map_err(|_| GetByteError::OutOfBits)?;
        Ok(buf[0])
    }

    /// Loads the next byte into `curr_byte`, skipping an emulation prevention
    /// byte when `needs_epb` is set and one is detected.
    fn move_to_next_byte(&mut self) -> Result<(), GetByteError> {
        let mut byte = self.get_byte()?;

        if self.needs_epb {
            if self.prev_two_bytes == 0 && byte == 0x03 {
                // Two zero bytes followed by 0x03: drop the 0x03 and take the
                // byte after it as the real data.
                self.num_epb += 1;
                byte = self.get_byte()?;
                // Reset the history so that two fresh zero bytes are required
                // before another emulation prevention byte can be detected.
                self.prev_two_bytes = 0xffff;
            }
            self.prev_two_bytes = (self.prev_two_bytes << 8) | u16::from(byte);
        }

        self.num_remaining_bits_in_curr_byte = 8;
        self.curr_byte = byte;
        Ok(())
    }
}
298
/// Iterator over the frame payloads of an in-memory IVF stream.
///
/// Each item is the payload of one frame, borrowed from the input data.
pub struct IvfIterator<'a> {
    /// Cursor positioned at the start of the next frame header.
    cursor: Cursor<&'a [u8]>,
}

impl<'a> IvfIterator<'a> {
    /// Number of bytes skipped at the start of the data before the first frame header.
    const FILE_HEADER_SIZE: u64 = 32;
    /// Number of bytes skipped between a frame's length field and its payload.
    const FRAME_TIMESTAMP_SIZE: usize = 8;

    /// Creates an iterator over `data`, skipping the first 32 bytes entirely.
    ///
    /// The skipped bytes are not inspected or validated.
    pub fn new(data: &'a [u8]) -> Self {
        let mut cursor = Cursor::new(data);
        cursor.set_position(Self::FILE_HEADER_SIZE);

        Self { cursor }
    }
}

impl<'a> Iterator for IvfIterator<'a> {
    type Item = &'a [u8];

    /// Returns the payload of the next frame.
    ///
    /// Returns `None` when the data is exhausted, and also when a frame header
    /// or payload is truncated; iteration ends there instead of panicking.
    fn next(&mut self) -> Option<Self::Item> {
        // Copy the inner `&'a [u8]` out so that the returned frame borrows the
        // underlying data rather than `self`.
        let data: &'a [u8] = *self.cursor.get_ref();

        // The frame header starts with the payload length, little-endian.
        let mut len_buf = [0u8; 4];
        self.cursor.read_exact(&mut len_buf).ok()?;
        let len = u32::from_le_bytes(len_buf) as usize;

        // Lossless cast: after a successful read the position is within `data`.
        let after_len = self.cursor.position() as usize;

        // Skip the 8 bytes after the length, then take `len` bytes of payload,
        // guarding both against overflow and against running past the data.
        let frame = after_len
            .checked_add(Self::FRAME_TIMESTAMP_SIZE)
            .and_then(|start| start.checked_add(len).map(|end| (start, end)))
            .and_then(|(start, end)| data.get(start..end).map(|payload| (payload, end)));

        match frame {
            Some((payload, end)) => {
                self.cursor.set_position(end as u64);
                Some(payload)
            }
            None => {
                // Truncated frame: move to the end so later calls also return
                // `None` instead of parsing payload bytes as a header.
                self.cursor.set_position(data.len() as u64);
                None
            }
        }
    }
}
337
/// The fields of an IVF file header, in the order they are serialized.
pub struct IvfFileHeader {
    /// File signature bytes.
    pub magic: [u8; 4],
    /// Format version.
    pub version: u16,
    /// Header size in bytes.
    pub header_size: u16,
    /// Four-character codec identifier.
    pub codec: [u8; 4],
    /// Frame width.
    pub width: u16,
    /// Frame height.
    pub height: u16,
    /// Frame rate value; `new` stores the requested rate multiplied by `timescale`.
    pub framerate: u32,
    /// Time scale value.
    pub timescale: u32,
    /// Number of frames.
    pub frame_count: u32,
    /// Unused field.
    pub unused: u32,
}

impl Default for IvfFileHeader {
    /// Builds a header describing a single-frame 320x240 VP9 stream with a
    /// timescale of 1000.
    fn default() -> Self {
        IvfFileHeader {
            magic: Self::MAGIC,
            version: 0,
            header_size: 32,
            codec: Self::CODEC_VP9,
            width: 320,
            height: 240,
            framerate: 1,
            timescale: 1000,
            frame_count: 1,
            unused: 0,
        }
    }
}

impl IvfFileHeader {
    /// Signature written at the start of the header.
    pub const MAGIC: [u8; 4] = *b"DKIF";
    /// Codec identifier for VP8.
    pub const CODEC_VP8: [u8; 4] = *b"VP80";
    /// Codec identifier for VP9.
    pub const CODEC_VP9: [u8; 4] = *b"VP90";
    /// Codec identifier for AV1.
    pub const CODEC_AV1: [u8; 4] = *b"AV01";

    /// Creates a header for the given codec, dimensions, frame rate and frame
    /// count; every other field takes its default value.
    ///
    /// The stored `framerate` is the given one multiplied by the default timescale.
    pub fn new(codec: [u8; 4], width: u16, height: u16, framerate: u32, frame_count: u32) -> Self {
        let base = Self::default();
        let scaled_framerate = framerate * base.timescale;

        Self {
            codec,
            width,
            height,
            framerate: scaled_framerate,
            frame_count,
            ..base
        }
    }

    /// Serializes the header into `writer`, field by field, with every integer
    /// encoded as little-endian.
    ///
    /// # Errors
    /// Returns the first I/O error reported by `writer`.
    pub fn writo_into(&self, writer: &mut impl std::io::Write) -> std::io::Result<()> {
        let version = self.version.to_le_bytes();
        let header_size = self.header_size.to_le_bytes();
        let width = self.width.to_le_bytes();
        let height = self.height.to_le_bytes();
        let framerate = self.framerate.to_le_bytes();
        let timescale = self.timescale.to_le_bytes();
        let frame_count = self.frame_count.to_le_bytes();
        let unused = self.unused.to_le_bytes();

        // One `write_all` per field, in declaration order.
        let fields: [&[u8]; 10] = [
            &self.magic,
            &version,
            &header_size,
            &self.codec,
            &width,
            &height,
            &framerate,
            &timescale,
            &frame_count,
            &unused,
        ];

        fields.iter().try_for_each(|field| writer.write_all(field))
    }
}
406
/// The fields of an IVF frame header, in the order they are serialized.
pub struct IvfFrameHeader {
    /// Frame size value, written first.
    pub frame_size: u32,
    /// Timestamp value, written second.
    pub timestamp: u64,
}

impl IvfFrameHeader {
    /// Serializes the header into `writer`: `frame_size` followed by
    /// `timestamp`, both little-endian.
    ///
    /// # Errors
    /// Returns the first I/O error reported by `writer`.
    pub fn writo_into(&self, writer: &mut impl std::io::Write) -> std::io::Result<()> {
        let size_le = self.frame_size.to_le_bytes();
        let timestamp_le = self.timestamp.to_le_bytes();

        writer.write_all(&size_le)?;
        writer.write_all(&timestamp_le)
    }
}
421
422pub struct NalIterator<'a, Nalu>(Cursor<&'a [u8]>, PhantomData<Nalu>);
424
425impl<'a, Nalu> NalIterator<'a, Nalu> {
426 pub fn new(stream: &'a [u8]) -> Self {
427 Self(Cursor::new(stream), PhantomData)
428 }
429}
430
431impl<'a> Iterator for NalIterator<'a, H264Nalu<'a>> {
432 type Item = Cow<'a, [u8]>;
433
434 fn next(&mut self) -> Option<Self::Item> {
435 H264Nalu::next(&mut self.0).map(|n| n.data).ok()
436 }
437}
438
/// Errors reported by `BitWriter`.
#[derive(Debug)]
pub enum BitWriterError {
    /// The requested number of bits is not supported.
    InvalidBitCount,
    /// The underlying writer reported an I/O error.
    Io(std::io::Error),
}

impl fmt::Display for BitWriterError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            Self::InvalidBitCount => f.write_str("invalid bit count"),
            // Shown exactly as the wrapped I/O error displays itself.
            Self::Io(err) => write!(f, "{}", err),
        }
    }
}

impl From<std::io::Error> for BitWriterError {
    /// Wraps an I/O error so that `?` can be used on writer operations.
    fn from(err: std::io::Error) -> Self {
        Self::Io(err)
    }
}

/// Result type used by `BitWriter` operations.
pub type BitWriterResult<T> = std::result::Result<T, BitWriterError>;
461
/// Writes individual bits, most significant bit first, to an underlying writer.
///
/// Bits are collected into a byte which is written out once it is full.
pub struct BitWriter<W: Write> {
    /// Destination for completed bytes.
    out: W,
    /// Number of bits already stored in `curr_byte` (0 to 7).
    nth_bit: u8,
    /// Byte currently being filled, from its most significant bit downwards.
    curr_byte: u8,
}

impl<W: Write> BitWriter<W> {
    /// Creates a bit writer that emits completed bytes to `writer`.
    pub fn new(writer: W) -> Self {
        BitWriter {
            out: writer,
            nth_bit: 0,
            curr_byte: 0,
        }
    }

    /// Writes the lowest `bits` bits of `value`, most significant bit first,
    /// and returns the number of bits written.
    ///
    /// # Errors
    /// Returns `BitWriterError::InvalidBitCount` if `bits` exceeds 32, or an
    /// I/O error from the underlying writer.
    pub fn write_f<T: Into<u32>>(&mut self, bits: usize, value: T) -> BitWriterResult<usize> {
        let value: u32 = value.into();

        if bits > 32 {
            return Err(BitWriterError::InvalidBitCount);
        }

        // Walk the bit positions from the highest requested one down to bit 0.
        for shift in (0..bits).rev() {
            self.write_bit((value >> shift) & 1 == 1)?;
        }

        Ok(bits)
    }

    /// Appends a single bit, writing the current byte out once it holds 8 bits.
    ///
    /// # Errors
    /// Returns an I/O error if writing a completed byte fails.
    pub fn write_bit(&mut self, bit: bool) -> BitWriterResult<()> {
        let mask = u8::from(bit) << (7 - self.nth_bit);
        self.curr_byte |= mask;
        self.nth_bit += 1;

        if self.nth_bit < 8 {
            return Ok(());
        }

        self.out.write_all(&[self.curr_byte])?;
        self.nth_bit = 0;
        self.curr_byte = 0;
        Ok(())
    }

    /// Writes out a partially filled byte, if any, and flushes the underlying
    /// writer.
    ///
    /// Returns the number of unused low-order bits in the byte that was
    /// written, or 0 if no partial byte was pending.
    ///
    /// # Errors
    /// Returns an I/O error if writing or flushing fails.
    pub fn flush(&mut self) -> BitWriterResult<u8> {
        let num_trailing_bits = if self.nth_bit == 0 {
            0
        } else {
            self.out.write_all(&[self.curr_byte])?;
            let unused = 8 - self.nth_bit;
            self.nth_bit = 0;
            self.curr_byte = 0;
            unused
        };

        self.out.flush()?;
        Ok(num_trailing_bits)
    }

    /// Returns whether some bits have not yet been written to the underlying
    /// writer.
    pub fn has_data_pending(&self) -> bool {
        self.nth_bit > 0
    }

    /// Returns a shared reference to the underlying writer.
    pub(crate) fn inner(&self) -> &W {
        &self.out
    }

    /// Returns a mutable reference to the underlying writer.
    pub(crate) fn inner_mut(&mut self) -> &mut W {
        &mut self.out
    }
}
538
539impl<W: Write> Drop for BitWriter<W> {
540 fn drop(&mut self) {
541 if let Err(e) = self.flush() {
542 log::error!("Unable to flush bits {e:?}");
543 }
544 }
545}