Skip to main content

moq_vaapi/
bitstream_utils.rs

1// Copyright 2024 The ChromiumOS Authors
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use std::borrow::Cow;
6use std::fmt;
7use std::io::Cursor;
8use std::io::Read;
9use std::io::Seek;
10use std::io::SeekFrom;
11use std::io::Write;
12use std::marker::PhantomData;
13
14use crate::codec::h264::parser::Nalu as H264Nalu;
15
/// A bit reader for codec bitstreams. It properly handles emulation-prevention
/// bytes and stop bits for H264.
#[derive(Clone)]
pub(crate) struct BitReader<'a> {
	/// The underlying byte stream; its position is the next byte that has not yet
	/// been loaded into `curr_byte`.
	data: Cursor<&'a [u8]>,
	/// Contents of the current byte. The first unread bit is at position
	/// `8 - num_remaining_bits_in_curr_byte`, counting from the MSB.
	curr_byte: u8,
	/// Number of bits remaining in `curr_byte`.
	num_remaining_bits_in_curr_byte: usize,
	/// The last two bytes loaded, used to detect `0x000003` emulation-prevention
	/// sequences.
	prev_two_bytes: u16,
	/// Number of emulation prevention bytes (i.e. 0x000003) we found.
	num_epb: usize,
	/// Whether or not we need emulation prevention logic.
	needs_epb: bool,
	/// How many bits have been read so far (skipped emulation-prevention bytes are
	/// not counted).
	position: u64,
}

/// Error returned when the reader needs another byte but the input is exhausted.
#[derive(Debug)]
pub(crate) enum GetByteError {
	OutOfBits,
}

impl fmt::Display for GetByteError {
	fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
		write!(f, "reader ran out of bits")
	}
}

/// Errors produced by the fixed-width read methods of [`BitReader`].
#[derive(Debug)]
pub(crate) enum ReadBitsError {
	/// More than 31 bits were requested in a single read.
	TooManyBitsRequested(usize),
	/// The input ended in the middle of a read.
	GetByte(GetByteError),
	/// The value read did not fit in the caller's requested type.
	ConversionFailed,
}

impl fmt::Display for ReadBitsError {
	fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
		match self {
			ReadBitsError::TooManyBitsRequested(bits) => {
				write!(f, "more than 31 ({}) bits were requested", bits)
			}
			ReadBitsError::GetByte(_) => write!(f, "failed to advance the current byte"),
			ReadBitsError::ConversionFailed => {
				write!(f, "failed to convert read input to target type")
			}
		}
	}
}

impl From<GetByteError> for ReadBitsError {
	fn from(err: GetByteError) -> Self {
		ReadBitsError::GetByte(err)
	}
}

impl<'a> BitReader<'a> {
	/// Creates a reader over `data`.
	///
	/// When `needs_epb` is true, a `0x03` byte that follows two zero bytes is
	/// treated as an emulation-prevention byte and skipped transparently.
	pub fn new(data: &'a [u8], needs_epb: bool) -> Self {
		Self {
			data: Cursor::new(data),
			curr_byte: Default::default(),
			num_remaining_bits_in_curr_byte: Default::default(),
			// Non-zero so the very first bytes cannot look like a 0x0000 prefix.
			prev_two_bytes: 0xffff,
			num_epb: Default::default(),
			needs_epb,
			position: 0,
		}
	}

	/// Read a single bit from the stream.
	pub fn read_bit(&mut self) -> Result<bool, String> {
		// `read_bits(1)` masks its result to a single bit, so it is either 0 or 1.
		Ok(self.read_bits::<u32>(1)? != 0)
	}

	/// Read up to 31 bits from the stream, MSB first. Note that we don't want to
	/// read 32 bits even though we're returning a u32 because that would break the
	/// read_bits_signed() function. 31 bits should be overkill for compressed
	/// header parsing anyway.
	///
	/// Reading 0 bits returns 0 and leaves the reader untouched.
	pub fn read_bits<U: TryFrom<u32>>(&mut self, num_bits: usize) -> Result<U, String> {
		if num_bits > 31 {
			return Err(ReadBitsError::TooManyBitsRequested(num_bits).to_string());
		}
		// Handled up front: right after a byte is loaded, the generic path below
		// would shift the u8 `curr_byte` by a full 8 bits, which overflows.
		if num_bits == 0 {
			return U::try_from(0u32).map_err(|_| ReadBitsError::ConversionFailed.to_string());
		}

		let mut bits_left = num_bits;
		let mut out = 0u32;

		// Take every remaining bit of the current byte until the rest of the request
		// fits inside one byte. Already-consumed bits of the first byte land above
		// `num_bits` and are masked off below.
		while self.num_remaining_bits_in_curr_byte < bits_left {
			out |= (self.curr_byte as u32) << (bits_left - self.num_remaining_bits_in_curr_byte);
			bits_left -= self.num_remaining_bits_in_curr_byte;
			self.move_to_next_byte().map_err(|err| err.to_string())?;
		}

		// Here 1 <= bits_left <= num_remaining_bits_in_curr_byte <= 8.
		out |= (self.curr_byte >> (self.num_remaining_bits_in_curr_byte - bits_left)) as u32;
		out &= (1 << num_bits) - 1;
		self.num_remaining_bits_in_curr_byte -= bits_left;
		self.position += num_bits as u64;

		U::try_from(out).map_err(|_| ReadBitsError::ConversionFailed.to_string())
	}

	/// Reads a two's complement signed integer of length |num_bits| (at most 31).
	/// Reading 0 bits yields 0.
	pub fn read_bits_signed<U: TryFrom<i32>>(&mut self, num_bits: usize) -> Result<U, String> {
		let raw = self.read_bits::<u32>(num_bits)?;
		let value: i32 = if num_bits == 0 {
			0
		} else {
			// Move the field's sign bit up to bit 31, then shift back arithmetically
			// so it is replicated into every higher bit. `read_bits` already rejected
			// `num_bits > 31`, so the shift amount is in 1..=31.
			let shift = (32 - num_bits) as u32;
			// `as i32` is a deliberate bit-for-bit reinterpretation.
			((raw << shift) as i32) >> shift
		};

		U::try_from(value).map_err(|_| ReadBitsError::ConversionFailed.to_string())
	}

	/// Reads an unsigned integer from the stream and checks if the stream is byte aligned.
	pub fn read_bits_aligned<U: TryFrom<u32>>(&mut self, num_bits: usize) -> Result<U, String> {
		if self.num_remaining_bits_in_curr_byte % 8 != 0 {
			return Err("Attempted unaligned read_le()".into());
		}

		self.read_bits(num_bits)
	}

	/// Skip `num_bits` bits from the stream, in chunks of at most 31 bits.
	pub fn skip_bits(&mut self, mut num_bits: usize) -> Result<(), String> {
		while num_bits > 0 {
			let n = std::cmp::min(num_bits, 31);
			self.read_bits::<u32>(n)?;
			num_bits -= n;
		}

		Ok(())
	}

	/// Returns the amount of bits left in the stream: the unread bits of the current
	/// byte plus every byte not yet loaded. Emulation-prevention bytes that have not
	/// been reached yet are included in the count.
	pub fn num_bits_left(&mut self) -> usize {
		let cur_pos = self.data.position();
		let end_pos = self
			.data
			.seek(SeekFrom::End(0))
			.expect("seeking a Cursor relative to its end cannot fail");
		let _ = self.data.seek(SeekFrom::Start(cur_pos));
		((end_pos - cur_pos) as usize) * 8 + self.num_remaining_bits_in_curr_byte
	}

	/// Returns the number of emulation-prevention bytes read so far.
	pub fn num_epb(&self) -> usize {
		self.num_epb
	}

	/// Whether the stream still has RBSP data. Implements more_rbsp_data(). See
	/// the spec for more details.
	///
	/// When this returns `false` because only zero bytes follow the current byte,
	/// those trailing zero bytes are left consumed.
	pub fn has_more_rsbp_data(&mut self) -> bool {
		// With no unread bits, load the next byte. If none is left, there is no
		// more data at all in the rbsp.
		if self.num_remaining_bits_in_curr_byte == 0 && self.move_to_next_byte().is_err() {
			return false;
		}

		// If the next bit is the stop bit, then we should only see unset bits
		// until the end of the data. Any set bit below it means more data.
		if (self.curr_byte & ((1 << (self.num_remaining_bits_in_curr_byte - 1)) - 1)) != 0 {
			return true;
		}

		// Otherwise there is more data only if a later byte is non-zero. Emulation-
		// prevention bytes are not interpreted by this scan.
		let mut buf = [0u8; 1];
		let orig_pos = self.data.position();
		while self.data.read_exact(&mut buf).is_ok() {
			if buf[0] != 0 {
				self.data.set_position(orig_pos);
				return true;
			}
		}
		false
	}

	/// Reads an unsigned Exp-Golomb coded number (`ue(v)`) from the bitstream. This
	/// may advance the position within the bitstream even if the read fails. See
	/// H.264 clause 9.1 for details.
	pub fn read_ue<U: TryFrom<u32>>(&mut self) -> Result<U, String> {
		let mut num_bits = 0;

		// Count leading zero bits up to and including the terminating 1 bit.
		while self.read_bits::<u32>(1)? == 0 {
			num_bits += 1;
			if num_bits > 31 {
				return Err("invalid stream".into());
			}
		}

		// codeNum = 2^leadingZeroBits - 1 + read_bits(leadingZeroBits).
		let value = ((1u32 << num_bits) - 1)
			.checked_add(self.read_bits::<u32>(num_bits)?)
			.ok_or_else(|| String::from("read number cannot fit in 32 bits"))?;

		U::try_from(value).map_err(|_| "conversion error".into())
	}

	/// Reads a `ue(v)` value and fails if it is outside `min..=max`.
	pub fn read_ue_bounded<U: TryFrom<u32>>(&mut self, min: u32, max: u32) -> Result<U, String> {
		let ue = self.read_ue::<u32>()?;
		if ue > max || ue < min {
			Err(format!("Value out of bounds: expected {} - {}, got {}", min, max, ue))
		} else {
			U::try_from(ue).map_err(|_| String::from("Conversion error"))
		}
	}

	/// Reads a `ue(v)` value and fails if it is greater than `max`.
	pub fn read_ue_max<U: TryFrom<u32>>(&mut self, max: u32) -> Result<U, String> {
		self.read_ue_bounded(0, max)
	}

	/// Reads a signed Exp-Golomb coded number (`se(v)`). Instead of using two's
	/// complement, this scheme maps odd codeNums to positive values and even
	/// codeNums to negative values (codeNum k maps to (-1)^(k+1) * ceil(k / 2)).
	/// See H.264 clause 9.1.1 for details.
	pub fn read_se<U: TryFrom<i32>>(&mut self) -> Result<U, String> {
		// codeNum can be as large as 2^32 - 2, so it must not be cast to i32 before
		// halving. ceil(codeNum / 2) is at most 2^31 - 1 and always fits in an i32.
		let code_num = self.read_ue::<u32>()?;
		let magnitude = i32::try_from(code_num / 2 + code_num % 2)
			.map_err(|_| String::from("Conversion error"))?;
		let value = if code_num % 2 == 0 { -magnitude } else { magnitude };

		U::try_from(value).map_err(|_| String::from("Conversion error"))
	}

	/// Reads an `se(v)` value and fails if it is outside `min..=max`.
	pub fn read_se_bounded<U: TryFrom<i32>>(&mut self, min: i32, max: i32) -> Result<U, String> {
		let se = self.read_se::<i32>()?;
		if se < min || se > max {
			Err(format!(
				"Value out of bounds, expected between {}-{}, got {}",
				min, max, se
			))
		} else {
			U::try_from(se).map_err(|_| String::from("Conversion error"))
		}
	}

	/// Read a little endian multi-byte integer.
	///
	/// Despite its name, `num_bits` is the number of *bytes* to read. Each byte must
	/// start on a byte boundary. At most 4 bytes fit in the u32 result; larger
	/// requests are rejected with an error instead of overflowing the shift.
	pub fn read_le<U: TryFrom<u32>>(&mut self, num_bits: u8) -> Result<U, String> {
		if num_bits > 4 {
			return Err(format!("read_le() can read at most 4 bytes, {} requested", num_bits));
		}

		let mut t = 0u32;
		for i in 0..u32::from(num_bits) {
			let byte = self.read_bits_aligned::<u32>(8)?;
			t |= byte << (i * 8);
		}

		U::try_from(t).map_err(|_| String::from("Conversion error"))
	}

	/// Return the position of this bitstream in bits.
	pub fn position(&self) -> u64 {
		self.position
	}

	/// Reads the next raw byte from `data`, without emulation-prevention handling.
	fn get_byte(&mut self) -> Result<u8, GetByteError> {
		let mut buf = [0u8; 1];
		self.data.read_exact(&mut buf).map_err(|_| GetByteError::OutOfBits)?;
		Ok(buf[0])
	}

	/// Loads the next byte into `curr_byte`, skipping a `0x03` that follows two
	/// zero bytes when `needs_epb` is set.
	fn move_to_next_byte(&mut self) -> Result<(), GetByteError> {
		let mut byte = self.get_byte()?;

		if self.needs_epb {
			if self.prev_two_bytes == 0 && byte == 0x03 {
				// We found an epb
				self.num_epb += 1;
				// Read another byte
				byte = self.get_byte()?;
				// We need another 3 bytes before another epb can happen.
				self.prev_two_bytes = 0xffff;
			}
			self.prev_two_bytes = (self.prev_two_bytes << 8) | u16::from(byte);
		}

		self.num_remaining_bits_in_curr_byte = 8;
		self.curr_byte = byte;
		Ok(())
	}
}
298
/// Iterator over IVF packets.
///
/// Each item borrows one frame payload from the input buffer. Iteration ends at
/// the first frame whose 12-byte frame header or payload is truncated, rather
/// than panicking.
pub struct IvfIterator<'a> {
	cursor: Cursor<&'a [u8]>,
}

impl<'a> IvfIterator<'a> {
	/// Creates an iterator over the frames stored in `data`.
	///
	/// The fixed 32-byte IVF file header is skipped without being validated.
	pub fn new(data: &'a [u8]) -> Self {
		let mut cursor = Cursor::new(data);

		// Skip the IVF header entirely.
		cursor
			.seek(std::io::SeekFrom::Start(32))
			.expect("seeking a Cursor from its start cannot fail");

		Self { cursor }
	}
}

impl<'a> Iterator for IvfIterator<'a> {
	type Item = &'a [u8];

	fn next(&mut self) -> Option<Self::Item> {
		let data: &'a [u8] = *self.cursor.get_ref();

		// Make sure we have a header: a little-endian u32 payload size...
		let mut len_buf = [0u8; 4];
		self.cursor.read_exact(&mut len_buf).ok()?;
		let len = u32::from_le_bytes(len_buf) as usize;

		// ...followed by an 8-byte PTS, which is skipped.
		let start = self.cursor.position() as usize + 8;

		// Bounds-checked slicing: a frame that claims more bytes than remain would
		// otherwise panic on the slice index.
		match start.checked_add(len).and_then(|end| data.get(start..end)) {
			Some(frame) => {
				self.cursor.set_position((start + frame.len()) as u64);
				Some(frame)
			}
			None => {
				// Park the cursor at the end so later calls also return `None`.
				self.cursor.set_position(data.len() as u64);
				None
			}
		}
	}
}
337
/// Helper struct for synthesizing an IVF file header.
///
/// [`IvfFileHeader::writo_into`] serializes the fields in declaration order,
/// multi-byte fields little-endian, for 32 bytes in total.
pub struct IvfFileHeader {
	/// File signature, `DKIF` by default.
	pub magic: [u8; 4],
	pub version: u16,
	/// Size of the serialized header in bytes (32 by default).
	pub header_size: u16,
	/// FourCC of the codec, e.g. [`IvfFileHeader::CODEC_VP9`].
	pub codec: [u8; 4],
	pub width: u16,
	pub height: u16,
	/// Rate field; [`IvfFileHeader::new`] stores `framerate * timescale` here.
	pub framerate: u32,
	pub timescale: u32,
	pub frame_count: u32,
	pub unused: u32,
}

impl Default for IvfFileHeader {
	/// A 320x240 VP9 header with one frame, rate 1 and timescale 1000.
	fn default() -> Self {
		IvfFileHeader {
			magic: IvfFileHeader::MAGIC,
			version: 0,
			header_size: 32,
			codec: IvfFileHeader::CODEC_VP9,
			width: 320,
			height: 240,
			framerate: 1,
			timescale: 1000,
			frame_count: 1,
			unused: 0,
		}
	}
}

impl IvfFileHeader {
	pub const MAGIC: [u8; 4] = *b"DKIF";
	pub const CODEC_VP8: [u8; 4] = *b"VP80";
	pub const CODEC_VP9: [u8; 4] = *b"VP90";
	pub const CODEC_AV1: [u8; 4] = *b"AV01";

	/// Builds a header for `codec` at `width`x`height`. The stored rate is
	/// `framerate` scaled by the default timescale. All other fields keep their
	/// default values.
	pub fn new(codec: [u8; 4], width: u16, height: u16, framerate: u32, frame_count: u32) -> Self {
		let base = IvfFileHeader::default();
		let scaled_rate = framerate * base.timescale;

		IvfFileHeader {
			codec,
			width,
			height,
			framerate: scaled_rate,
			frame_count,
			..base
		}
	}
}

impl IvfFileHeader {
	/// Writes header into writer, one field at a time, in declaration order.
	pub fn writo_into(&self, writer: &mut impl std::io::Write) -> std::io::Result<()> {
		// Temporaries in the `for` head live until the loop finishes.
		for field in [
			&self.magic[..],
			&self.version.to_le_bytes()[..],
			&self.header_size.to_le_bytes()[..],
			&self.codec[..],
			&self.width.to_le_bytes()[..],
			&self.height.to_le_bytes()[..],
			&self.framerate.to_le_bytes()[..],
			&self.timescale.to_le_bytes()[..],
			&self.frame_count.to_le_bytes()[..],
			&self.unused.to_le_bytes()[..],
		] {
			writer.write_all(field)?;
		}

		Ok(())
	}
}
406
/// Helper struct for synthesizing an IVF frame header.
///
/// It serializes to 12 bytes: `frame_size` followed by `timestamp`, both
/// little-endian.
pub struct IvfFrameHeader {
	pub frame_size: u32,
	pub timestamp: u64,
}

impl IvfFrameHeader {
	/// Writes header into writer.
	pub fn writo_into(&self, writer: &mut impl std::io::Write) -> std::io::Result<()> {
		let size_bytes = self.frame_size.to_le_bytes();
		let timestamp_bytes = self.timestamp.to_le_bytes();

		writer.write_all(&size_bytes)?;
		writer.write_all(&timestamp_bytes)?;
		Ok(())
	}
}
421
/// Iterator over the NAL units of a bitstream.
///
/// `Nalu` is only a type marker (stored as `PhantomData`) that selects which
/// `Iterator` implementation parses the stream.
pub struct NalIterator<'a, Nalu>(Cursor<&'a [u8]>, PhantomData<Nalu>);

impl<'a, Nalu> NalIterator<'a, Nalu> {
	/// Creates an iterator positioned at the start of `stream`.
	pub fn new(stream: &'a [u8]) -> Self {
		let cursor = Cursor::new(stream);
		NalIterator(cursor, PhantomData)
	}
}
430
431impl<'a> Iterator for NalIterator<'a, H264Nalu<'a>> {
432	type Item = Cow<'a, [u8]>;
433
434	fn next(&mut self) -> Option<Self::Item> {
435		H264Nalu::next(&mut self.0).map(|n| n.data).ok()
436	}
437}
438
/// Errors reported by `BitWriter`.
#[derive(Debug)]
pub enum BitWriterError {
	/// More bits were requested than fit in the value being written.
	InvalidBitCount,
	/// The underlying writer failed.
	Io(std::io::Error),
}

impl fmt::Display for BitWriterError {
	fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
		match self {
			BitWriterError::InvalidBitCount => write!(f, "invalid bit count"),
			// Format the I/O error directly; no intermediate `String` is needed.
			BitWriterError::Io(x) => write!(f, "{}", x),
		}
	}
}

impl From<std::io::Error> for BitWriterError {
	fn from(err: std::io::Error) -> Self {
		BitWriterError::Io(err)
	}
}

/// Result type returned by `BitWriter` operations.
pub type BitWriterResult<T> = std::result::Result<T, BitWriterError>;
461
462pub struct BitWriter<W: Write> {
463	out: W,
464	nth_bit: u8,
465	curr_byte: u8,
466}
467
468impl<W: Write> BitWriter<W> {
469	pub fn new(writer: W) -> Self {
470		Self {
471			out: writer,
472			curr_byte: 0,
473			nth_bit: 0,
474		}
475	}
476
477	/// Writes fixed bit size integer (up to 32 bit)
478	pub fn write_f<T: Into<u32>>(&mut self, bits: usize, value: T) -> BitWriterResult<usize> {
479		let value = value.into();
480
481		if bits > 32 {
482			return Err(BitWriterError::InvalidBitCount);
483		}
484
485		let mut written = 0;
486		for bit in (0..bits).rev() {
487			let bit = (1 << bit) as u32;
488
489			self.write_bit((value & bit) == bit)?;
490			written += 1;
491		}
492
493		Ok(written)
494	}
495
496	/// Takes a single bit that will be outputed to [`std::io::Write`]
497	pub fn write_bit(&mut self, bit: bool) -> BitWriterResult<()> {
498		self.curr_byte |= (bit as u8) << (7u8 - self.nth_bit);
499		self.nth_bit += 1;
500
501		if self.nth_bit == 8 {
502			self.out.write_all(&[self.curr_byte])?;
503			self.nth_bit = 0;
504			self.curr_byte = 0;
505		}
506
507		Ok(())
508	}
509
510	/// Immediately outputs any cached bits to [`std::io::Write`]
511	/// and returns the number of trailing bits in the last byte.
512	pub fn flush(&mut self) -> BitWriterResult<u8> {
513		let mut num_trailing_bits = 0;
514		if self.nth_bit != 0 {
515			self.out.write_all(&[self.curr_byte])?;
516			num_trailing_bits = 8 - self.nth_bit;
517			self.nth_bit = 0;
518			self.curr_byte = 0;
519		}
520
521		self.out.flush()?;
522		Ok(num_trailing_bits)
523	}
524
525	/// Returns `true` if ['Self`] hold data that wasn't written to [`std::io::Write`]
526	pub fn has_data_pending(&self) -> bool {
527		self.nth_bit != 0
528	}
529
530	pub(crate) fn inner(&self) -> &W {
531		&self.out
532	}
533
534	pub(crate) fn inner_mut(&mut self) -> &mut W {
535		&mut self.out
536	}
537}
538
539impl<W: Write> Drop for BitWriter<W> {
540	fn drop(&mut self) {
541		if let Err(e) = self.flush() {
542			log::error!("Unable to flush bits {e:?}");
543		}
544	}
545}