use std::fmt;
/// Errors reported by the bitstream filters and parsers in this module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BitstreamFilterError {
    /// A fixed-size field extends past the end of the buffer.
    BufferTooShort { needed: usize, available: usize },
    /// A length field at `offset` claims `claimed` bytes, more than the
    /// `available` bytes that remain.
    InvalidLengthPrefix {
        offset: usize,
        claimed: usize,
        available: usize,
    },
    /// An AV1 OBU header near `offset` is invalid or truncated.
    MalformedObuHeader { offset: usize },
    /// An AV1 sequence header payload is truncated or invalid.
    MalformedSequenceHeader,
    /// The input packet contains no bytes.
    EmptyPacket,
    /// Carries an H.264 NAL unit type value that was not expected.
    UnknownNalType(u8),
    /// A length-prefix width other than 1, 2 or 4 bytes was requested.
    InvalidLengthPrefixSize(u8),
}

impl fmt::Display for BitstreamFilterError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Every payload is `Copy`, so bind by value straight out of `*self`.
        match *self {
            Self::BufferTooShort { needed, available } => {
                write!(f, "buffer too short: needed {needed}, available {available}")
            }
            Self::InvalidLengthPrefix {
                offset,
                claimed,
                available,
            } => write!(
                f,
                "invalid length prefix at offset {offset}: claims {claimed} bytes but only {available} remain"
            ),
            Self::MalformedObuHeader { offset } => {
                write!(f, "malformed OBU header at offset {offset}")
            }
            Self::MalformedSequenceHeader => f.write_str("malformed AV1 sequence header"),
            Self::EmptyPacket => f.write_str("packet is empty"),
            Self::UnknownNalType(t) => write!(f, "unknown NAL unit type: {t}"),
            Self::InvalidLengthPrefixSize(s) => {
                write!(f, "invalid length prefix size: {s} (must be 1, 2, or 4)")
            }
        }
    }
}

impl std::error::Error for BitstreamFilterError {}

/// Result alias used by every fallible function in this module.
pub type BitstreamResult<T> = Result<T, BitstreamFilterError>;
/// Width, in bytes, of the big-endian length field that precedes each NAL
/// unit in a length-prefixed (AVCC) stream.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LengthPrefixSize {
    /// One-byte length field.
    One = 1,
    /// Two-byte length field.
    Two = 2,
    /// Four-byte length field.
    Four = 4,
}

impl LengthPrefixSize {
    /// Converts a raw byte count into a [`LengthPrefixSize`].
    ///
    /// # Errors
    ///
    /// Returns [`BitstreamFilterError::InvalidLengthPrefixSize`] carrying
    /// `raw` when it is not 1, 2 or 4.
    pub fn from_raw(raw: u8) -> BitstreamResult<Self> {
        // The discriminants equal the byte counts, so match on them directly.
        [Self::One, Self::Two, Self::Four]
            .iter()
            .copied()
            .find(|size| *size as u8 == raw)
            .ok_or(BitstreamFilterError::InvalidLengthPrefixSize(raw))
    }

    /// Returns the prefix width as a byte count (1, 2 or 4).
    pub fn as_usize(self) -> usize {
        usize::from(self as u8)
    }
}
/// Three-byte Annex B start code (`00 00 01`).
const START_CODE_3: [u8; 3] = [0x00, 0x00, 0x01];
/// Four-byte Annex B start code (`00 00 00 01`).
const START_CODE_4: [u8; 4] = [0x00, 0x00, 0x00, 0x01];

/// Splits an Annex B byte stream into its NAL units.
///
/// Both 3-byte (`00 00 01`) and 4-byte (`00 00 00 01`) start codes are
/// recognised and removed. Any trailing `0x00` bytes are stripped from every
/// unit. H.264 forbids a NAL unit from ending in a zero byte, so such bytes
/// can only be `trailing_zero_8bits` padding or the extra zeros of a longer
/// start code. Units that are empty after stripping are dropped. Bytes that
/// precede the first start code, if any, are returned as a unit of their own.
///
/// The returned slices borrow from `data`; an empty input yields an empty
/// vector.
pub fn split_annexb(data: &[u8]) -> Vec<&[u8]> {
    // Pushes `nal` without its trailing zero bytes, unless nothing remains.
    fn push_trimmed<'a>(nals: &mut Vec<&'a [u8]>, nal: &'a [u8]) {
        let end = nal
            .iter()
            .rposition(|&b| b != 0x00)
            .map_or(0, |last| last + 1);
        if end > 0 {
            nals.push(&nal[..end]);
        }
    }

    let mut nals: Vec<&[u8]> = Vec::new();
    let len = data.len();
    // Skip a start code at the very beginning of the buffer.
    let mut start = if data.starts_with(&START_CODE_4) {
        START_CODE_4.len()
    } else if data.starts_with(&START_CODE_3) {
        START_CODE_3.len()
    } else {
        0
    };
    let mut i = start;
    while i + 2 < len {
        if data[i] == 0x00 && data[i + 1] == 0x00 {
            // Prefer the 4-byte form so its leading zero is not left in the
            // previous unit.
            let code_len = if i + 3 < len && data[i + 2] == 0x00 && data[i + 3] == 0x01 {
                START_CODE_4.len()
            } else if data[i + 2] == 0x01 {
                START_CODE_3.len()
            } else {
                0
            };
            if code_len != 0 {
                push_trimmed(&mut nals, &data[start..i]);
                i += code_len;
                start = i;
                continue;
            }
        }
        i += 1;
    }
    push_trimmed(&mut nals, &data[start..]);
    nals
}
/// Converts an Annex B byte stream into length-prefixed (AVCC) form.
///
/// The input is split with [`split_annexb`]. Each resulting NAL unit is
/// written as a big-endian length field of `prefix_size` bytes, followed by
/// the NAL unit bytes. Start codes are not copied to the output.
///
/// # Errors
///
/// * [`BitstreamFilterError::EmptyPacket`] if `data` is empty.
/// * [`BitstreamFilterError::InvalidLengthPrefix`] if a NAL unit is longer
///   than `prefix_size` can encode (e.g. more than 255 bytes with a 1-byte
///   prefix). In that case:
///   - `offset` is the NAL unit's byte offset within `data`;
///   - `claimed` is its length;
///   - `available` is the largest length the prefix can represent.
pub fn annexb_to_avcc(data: &[u8], prefix_size: LengthPrefixSize) -> BitstreamResult<Vec<u8>> {
    if data.is_empty() {
        return Err(BitstreamFilterError::EmptyPacket);
    }
    let nals = split_annexb(data);
    let prefix_bytes = prefix_size.as_usize();
    // Largest length representable in the chosen prefix width. Without this
    // bound the `as` casts below would silently truncate the length and emit
    // a stream whose prefixes no longer match the payloads.
    let max_len = match prefix_size {
        LengthPrefixSize::One => usize::from(u8::MAX),
        LengthPrefixSize::Two => usize::from(u16::MAX),
        LengthPrefixSize::Four => usize::try_from(u32::MAX).unwrap_or(usize::MAX),
    };
    let total: usize = nals.iter().map(|n| prefix_bytes + n.len()).sum();
    let mut out = Vec::with_capacity(total);
    for nal in nals {
        let nal_len = nal.len();
        if nal_len > max_len {
            return Err(BitstreamFilterError::InvalidLengthPrefix {
                // `nal` is a sub-slice of `data`, so the address difference
                // is its starting offset within the input.
                offset: nal.as_ptr() as usize - data.as_ptr() as usize,
                claimed: nal_len,
                available: max_len,
            });
        }
        // `nal_len <= max_len`, so none of these casts can truncate.
        match prefix_size {
            LengthPrefixSize::One => out.push(nal_len as u8),
            LengthPrefixSize::Two => out.extend_from_slice(&(nal_len as u16).to_be_bytes()),
            LengthPrefixSize::Four => out.extend_from_slice(&(nal_len as u32).to_be_bytes()),
        }
        out.extend_from_slice(nal);
    }
    Ok(out)
}
/// Converts a length-prefixed (AVCC) buffer into Annex B form.
///
/// Each `prefix_size`-byte big-endian length is replaced by a 4-byte start
/// code (`00 00 00 01`), and the NAL unit bytes are copied unchanged. A
/// zero-length entry produces a bare start code.
///
/// # Errors
///
/// * [`BitstreamFilterError::EmptyPacket`] if `data` is empty.
/// * [`BitstreamFilterError::BufferTooShort`] if a length prefix is cut off
///   by the end of the buffer.
/// * [`BitstreamFilterError::InvalidLengthPrefix`] if a length exceeds the
///   bytes remaining after its prefix. `offset` is the position of that
///   prefix.
pub fn avcc_to_annexb(data: &[u8], prefix_size: LengthPrefixSize) -> BitstreamResult<Vec<u8>> {
    if data.is_empty() {
        return Err(BitstreamFilterError::EmptyPacket);
    }
    let width = prefix_size.as_usize();
    let total = data.len();
    // Start codes are at most as large as the prefixes they replace (plus a
    // little), so a 25% head-room estimate avoids most reallocations.
    let mut annexb = Vec::with_capacity(total + total / 4);
    let mut pos = 0usize;
    while pos < total {
        let body_start = pos + width;
        if body_start > total {
            return Err(BitstreamFilterError::BufferTooShort {
                needed: body_start,
                available: total,
            });
        }
        let nal_len = read_be_uint(&data[pos..body_start], width);
        let available = total - body_start;
        if nal_len > available {
            return Err(BitstreamFilterError::InvalidLengthPrefix {
                offset: pos,
                claimed: nal_len,
                available,
            });
        }
        let body_end = body_start + nal_len;
        annexb.extend_from_slice(&START_CODE_4);
        annexb.extend_from_slice(&data[body_start..body_end]);
        pos = body_end;
    }
    Ok(annexb)
}
/// Decodes the first `n` bytes of `bytes` as a big-endian unsigned integer.
///
/// Any width up to `size_of::<usize>()` bytes is supported, not just 1, 2
/// and 4. An unexpected width therefore decodes correctly instead of
/// silently yielding `0`. `n == 0` yields `0`.
///
/// # Panics
///
/// Panics if `bytes` is shorter than `n`.
fn read_be_uint(bytes: &[u8], n: usize) -> usize {
    debug_assert!(
        n <= std::mem::size_of::<usize>(),
        "big-endian width exceeds usize"
    );
    bytes[..n]
        .iter()
        .fold(0usize, |acc, &b| (acc << 8) | usize::from(b))
}
/// H.264 NAL unit types distinguished by this module.
///
/// Any `nal_unit_type` not listed here maps to [`H264NalType::Other`],
/// which carries the raw 5-bit value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum H264NalType {
    /// `nal_unit_type` 1.
    NonIdrSlice,
    /// `nal_unit_type` 5.
    IdrSlice,
    /// `nal_unit_type` 6.
    Sei,
    /// `nal_unit_type` 7.
    Sps,
    /// `nal_unit_type` 8.
    Pps,
    /// `nal_unit_type` 9.
    Aud,
    /// `nal_unit_type` 10.
    EndOfSeq,
    /// `nal_unit_type` 11.
    EndOfStream,
    /// `nal_unit_type` 12.
    FillerData,
    /// Any other `nal_unit_type` (0..=31).
    Other(u8),
}

impl H264NalType {
    /// Classifies a NAL unit from its first (header) byte.
    ///
    /// Only the low five bits are inspected; the upper three bits are
    /// ignored.
    pub fn from_nal_byte(byte: u8) -> Self {
        // Nested items cannot name `Self`, so spell out the type.
        const NAMED: [(u8, H264NalType); 9] = [
            (1, H264NalType::NonIdrSlice),
            (5, H264NalType::IdrSlice),
            (6, H264NalType::Sei),
            (7, H264NalType::Sps),
            (8, H264NalType::Pps),
            (9, H264NalType::Aud),
            (10, H264NalType::EndOfSeq),
            (11, H264NalType::EndOfStream),
            (12, H264NalType::FillerData),
        ];
        let nal_unit_type = byte & 0x1F;
        NAMED
            .iter()
            .find(|(code, _)| *code == nal_unit_type)
            .map_or(Self::Other(nal_unit_type), |&(_, known)| known)
    }
}
/// A single H.264 NAL unit borrowed from a larger buffer.
#[derive(Debug, Clone)]
pub struct NalUnit<'a> {
    /// Type decoded from the first byte of `data`.
    pub nal_type: H264NalType,
    /// The complete NAL unit bytes, including the header byte.
    pub data: &'a [u8],
}

impl<'a> NalUnit<'a> {
    /// Wraps `data` as a NAL unit, classifying it by its first byte.
    ///
    /// Returns `None` for an empty slice.
    pub fn from_raw(data: &'a [u8]) -> Option<Self> {
        data.first().map(|&header| NalUnit {
            nal_type: H264NalType::from_nal_byte(header),
            data,
        })
    }
}
/// Returns every SPS NAL unit found in an Annex B stream, in stream order.
///
/// The returned units borrow from `data`.
pub fn extract_sps(data: &[u8]) -> Vec<NalUnit<'_>> {
    let mut found = Vec::new();
    for raw in split_annexb(data) {
        if let Some(nal) = NalUnit::from_raw(raw).filter(|n| n.nal_type == H264NalType::Sps) {
            found.push(nal);
        }
    }
    found
}
/// Returns every PPS NAL unit found in an Annex B stream, in stream order.
///
/// The returned units borrow from `data`.
pub fn extract_pps(data: &[u8]) -> Vec<NalUnit<'_>> {
    let mut found = Vec::new();
    for raw in split_annexb(data) {
        if let Some(nal) = NalUnit::from_raw(raw).filter(|n| n.nal_type == H264NalType::Pps) {
            found.push(nal);
        }
    }
    found
}
/// Collects owned copies of all SPS and PPS NAL units in an Annex B stream.
///
/// Returns `(sps_list, pps_list)`, each in stream order. Every NAL unit is
/// copied in full, including its header byte.
pub fn extract_sps_pps(data: &[u8]) -> (Vec<Vec<u8>>, Vec<Vec<u8>>) {
    let mut parameter_sets = (Vec::new(), Vec::new());
    for nal in split_annexb(data).into_iter().filter_map(NalUnit::from_raw) {
        let bucket = match nal.nal_type {
            H264NalType::Sps => &mut parameter_sets.0,
            H264NalType::Pps => &mut parameter_sets.1,
            _ => continue,
        };
        bucket.push(nal.data.to_vec());
    }
    parameter_sets
}
/// AV1 OBU types, decoded from the 4-bit `obu_type` field of an OBU header.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Av1ObuType {
    /// `obu_type` 1.
    SequenceHeader,
    /// `obu_type` 2.
    TemporalDelimiter,
    /// `obu_type` 3.
    FrameHeader,
    /// `obu_type` 4.
    TileGroup,
    /// `obu_type` 5.
    Metadata,
    /// `obu_type` 6.
    Frame,
    /// `obu_type` 7.
    RedundantFrameHeader,
    /// `obu_type` 8.
    TileList,
    /// `obu_type` 15.
    Padding,
    /// Any other value, carried verbatim.
    Reserved(u8),
}

impl Av1ObuType {
    /// Maps a raw `obu_type` value to its variant.
    fn from_raw(raw: u8) -> Self {
        // Types 1..=8 are contiguous, so index a table for them.
        const CONTIGUOUS: [Av1ObuType; 8] = [
            Av1ObuType::SequenceHeader,
            Av1ObuType::TemporalDelimiter,
            Av1ObuType::FrameHeader,
            Av1ObuType::TileGroup,
            Av1ObuType::Metadata,
            Av1ObuType::Frame,
            Av1ObuType::RedundantFrameHeader,
            Av1ObuType::TileList,
        ];
        match raw {
            1..=8 => CONTIGUOUS[usize::from(raw - 1)],
            15 => Self::Padding,
            other => Self::Reserved(other),
        }
    }
}

/// One OBU split out of a bitstream: its type and an owned copy of its payload.
#[derive(Debug, Clone)]
pub struct Av1Obu {
    /// Decoded `obu_type`.
    pub obu_type: Av1ObuType,
    /// Payload bytes following the OBU header (and size field, if present).
    pub payload: Vec<u8>,
}
/// Reads an unsigned LEB128 value (AV1 `leb128()`) starting at `data[offset]`.
///
/// Returns `(value, bytes_consumed)`. At most eight bytes are read, and the
/// decoded value must not exceed `u32::MAX`; both are AV1 bitstream
/// conformance requirements for `leb128()`. The value bound also guarantees
/// the result fits in `usize` on 32-bit targets.
///
/// # Errors
///
/// Returns [`BitstreamFilterError::MalformedObuHeader`] carrying `offset` if
/// any of the following holds:
/// * the buffer ends mid-value;
/// * the eighth byte still has its continuation bit set;
/// * the value exceeds `u32::MAX`.
fn read_leb128(data: &[u8], offset: usize) -> BitstreamResult<(u64, usize)> {
    const MAX_BYTES: usize = 8;
    let malformed = || BitstreamFilterError::MalformedObuHeader { offset };
    let mut value: u64 = 0;
    for i in 0..MAX_BYTES {
        let byte = *data.get(offset + i).ok_or_else(malformed)?;
        value |= u64::from(byte & 0x7F) << (7 * i);
        if byte & 0x80 == 0 {
            return if value <= u64::from(u32::MAX) {
                Ok((value, i + 1))
            } else {
                Err(malformed())
            };
        }
    }
    // Continuation bit still set on the eighth byte.
    Err(malformed())
}
/// Splits a buffer of AV1 OBUs into individual OBUs.
///
/// For each OBU header, the following are decoded:
/// * `obu_forbidden_bit`, which must be zero;
/// * `obu_type`;
/// * `obu_extension_flag`;
/// * `obu_has_size_field`.
///
/// The optional extension byte is skipped and not retained. When
/// `obu_has_size_field` is clear, the OBU is taken to extend to the end of
/// `data`. Payloads are copied into owned buffers.
///
/// # Errors
///
/// * [`BitstreamFilterError::EmptyPacket`] if `data` is empty.
/// * [`BitstreamFilterError::MalformedObuHeader`] if the forbidden bit is
///   set, the extension byte is missing, or `obu_size` cannot be decoded.
/// * [`BitstreamFilterError::InvalidLengthPrefix`] if `obu_size` exceeds the
///   bytes remaining after the header.
pub fn split_av1_obus(data: &[u8]) -> BitstreamResult<Vec<Av1Obu>> {
    if data.is_empty() {
        return Err(BitstreamFilterError::EmptyPacket);
    }
    let len = data.len();
    let mut obus = Vec::new();
    let mut offset = 0usize;
    while offset < len {
        let header_byte = data[offset];
        // obu_forbidden_bit
        if header_byte & 0x80 != 0 {
            return Err(BitstreamFilterError::MalformedObuHeader { offset });
        }
        let obu_type_raw = (header_byte >> 3) & 0x0F;
        let has_extension = (header_byte >> 2) & 1 == 1;
        let has_size_field = (header_byte >> 1) & 1 == 1;
        offset += 1;
        if has_extension {
            if offset >= len {
                return Err(BitstreamFilterError::MalformedObuHeader { offset });
            }
            offset += 1;
        }
        let payload_len = if has_size_field {
            let (size, consumed) = read_leb128(data, offset)?;
            offset += consumed;
            // A size that does not fit in `usize` cannot fit in `data`.
            // Saturating makes the bounds check below reject it, where an
            // `as` cast would wrap it on 32-bit targets.
            usize::try_from(size).unwrap_or(usize::MAX)
        } else {
            len - offset
        };
        // Every byte consumed so far was bounds-checked, so `offset <= len`
        // and this cannot underflow; comparing against the remainder also
        // avoids an `offset + payload_len` overflow.
        let available = len - offset;
        if payload_len > available {
            return Err(BitstreamFilterError::InvalidLengthPrefix {
                offset,
                claimed: payload_len,
                available,
            });
        }
        obus.push(Av1Obu {
            obu_type: Av1ObuType::from_raw(obu_type_raw),
            payload: data[offset..offset + payload_len].to_vec(),
        });
        offset += payload_len;
    }
    Ok(obus)
}
/// Subset of the fields of an AV1 sequence header OBU, as produced by
/// `parse_av1_sequence_header`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Av1SequenceHeader {
/// `seq_profile` (3-bit field).
pub seq_profile: u8,
/// `still_picture` flag.
pub still_picture: bool,
/// `reduced_still_picture_header` flag.
pub reduced_still_picture_header: bool,
/// Maximum frame width (`max_frame_width_minus_1 + 1`).
pub max_frame_width: u32,
/// Maximum frame height (`max_frame_height_minus_1 + 1`).
pub max_frame_height: u32,
/// `high_bitdepth` flag from the color configuration.
pub high_bitdepth: bool,
/// `twelve_bit` flag; only read from the bitstream when `seq_profile == 2`
/// and `high_bitdepth` is set, otherwise `false`.
pub twelve_bit: bool,
/// `mono_chrome` flag; always `false` when `seq_profile == 1`.
pub mono_chrome: bool,
}
/// MSB-first bit reader over a byte slice, used for AV1 header parsing.
struct BitReader<'a> {
    /// Bytes being read.
    data: &'a [u8],
    /// Index of the byte holding the next unread bit.
    byte_offset: usize,
    /// Position of the next unread bit within that byte (0 = most significant).
    bit_offset: u8,
}

impl<'a> BitReader<'a> {
    /// Creates a reader positioned at the first bit of `data`.
    fn new(data: &'a [u8]) -> Self {
        BitReader {
            data,
            byte_offset: 0,
            bit_offset: 0,
        }
    }

    /// Reads one bit (0 or 1).
    ///
    /// Fails with [`BitstreamFilterError::MalformedSequenceHeader`] once the
    /// data is exhausted.
    fn read_bit(&mut self) -> BitstreamResult<u8> {
        let current = match self.data.get(self.byte_offset) {
            Some(&b) => b,
            None => return Err(BitstreamFilterError::MalformedSequenceHeader),
        };
        let bit = (current >> (7 - self.bit_offset)) & 1;
        if self.bit_offset == 7 {
            self.byte_offset += 1;
            self.bit_offset = 0;
        } else {
            self.bit_offset += 1;
        }
        Ok(bit)
    }

    /// Reads `n` bits MSB-first into the low bits of a `u32`.
    fn read_bits(&mut self, n: u8) -> BitstreamResult<u32> {
        let mut acc = 0u32;
        let mut left = n;
        while left > 0 {
            acc = (acc << 1) | u32::from(self.read_bit()?);
            left -= 1;
        }
        Ok(acc)
    }

    /// Alias of [`BitReader::read_bits`].
    fn u(&mut self, n: u8) -> BitstreamResult<u32> {
        self.read_bits(n)
    }

    /// Alias of [`BitReader::read_bits`], named after the AV1 `f(n)` descriptor.
    fn f(&mut self, n: u8) -> BitstreamResult<u32> {
        self.read_bits(n)
    }
}
pub fn parse_av1_sequence_header(payload: &[u8]) -> BitstreamResult<Av1SequenceHeader> {
let mut r = BitReader::new(payload);
let seq_profile = r.f(3)? as u8;
let still_picture = r.f(1)? != 0;
let reduced_still_picture_header = r.f(1)? != 0;
let (timing_info_present, decoder_model_info_present) = if reduced_still_picture_header {
(false, false)
} else {
let tip = r.f(1)? != 0;
let dmip = if tip {
r.u(32)?;
r.u(32)?;
let epi = r.f(1)?;
if epi != 0 {
let _ = read_uvlc(&mut r)?;
}
r.f(1)? != 0
} else {
false
};
if dmip {
let _ = r.u(5)?;
let _ = r.u(32)?;
let _ = r.u(9)?;
}
(tip, dmip)
};
let _ = timing_info_present;
let _ = decoder_model_info_present;
if !reduced_still_picture_header {
let op_cnt = r.u(5)?; for _ in 0..=op_cnt {
let _op_idc = r.u(12)?;
let _seq_level_idx = r.u(5)?;
let seq_tier = if r.u(5)? > 7 { r.u(1)? } else { 0 };
let _ = seq_tier;
if decoder_model_info_present {
let _decoder_model_present = r.u(1)?;
}
if !reduced_still_picture_header {
let _initial_display_delay_present = r.u(1)?;
if decoder_model_info_present {
let _initial_display_delay_minus_1 = r.u(4)?;
}
}
}
}
let fw_bits = r.u(4)? + 1;
let fh_bits = r.u(4)? + 1;
let max_frame_width = r.u(fw_bits as u8)? + 1;
let max_frame_height = r.u(fh_bits as u8)? + 1;
if !reduced_still_picture_header {
let frame_id_numbers_present = r.u(1)?;
if frame_id_numbers_present != 0 {
let _delta_frame_id_length = r.u(4)?;
let _additional_frame_id_length = r.u(3)?;
}
}
let _use_128 = r.u(1)?;
let _enable_filter_intra = r.u(1)?;
let _enable_intra_edge_filter = r.u(1)?;
let high_bitdepth = r.u(1)? != 0;
let twelve_bit = if seq_profile == 2 && high_bitdepth {
r.u(1)? != 0
} else {
false
};
let mono_chrome = if seq_profile == 1 {
false
} else {
r.u(1)? != 0
};
Ok(Av1SequenceHeader {
seq_profile,
still_picture,
reduced_still_picture_header,
max_frame_width,
max_frame_height,
high_bitdepth,
twelve_bit,
mono_chrome,
})
}
/// Reads an AV1 `uvlc()` variable-length code.
///
/// The code is a run of leading zero bits, a terminating one bit, and then
/// as many value bits as there were leading zeros. The result is
/// `(1 << leading_zeros) - 1 + value`.
///
/// # Errors
///
/// Returns [`BitstreamFilterError::MalformedSequenceHeader`] if 32 leading
/// zeros are seen, or if the reader runs out of data.
fn read_uvlc(r: &mut BitReader<'_>) -> BitstreamResult<u32> {
    let mut leading_zeros = 0u32;
    while r.read_bit()? == 0 {
        leading_zeros += 1;
        if leading_zeros >= 32 {
            return Err(BitstreamFilterError::MalformedSequenceHeader);
        }
    }
    // With zero leading zeros this reads nothing and yields 0.
    let suffix = r.read_bits(leading_zeros as u8)?;
    // leading_zeros <= 31 here, so this cannot overflow u32.
    Ok((1u32 << leading_zeros) - 1 + suffix)
}
/// Splits `data` into OBUs and parses the first sequence header OBU, if any.
///
/// Returns `Ok(None)` when no sequence header OBU is present.
///
/// # Errors
///
/// Propagates errors from [`split_av1_obus`] and from parsing the first
/// sequence header found.
pub fn find_av1_sequence_header(data: &[u8]) -> BitstreamResult<Option<Av1SequenceHeader>> {
    let obus = split_av1_obus(data)?;
    obus.iter()
        .find(|obu| obu.obu_type == Av1ObuType::SequenceHeader)
        .map(|obu| parse_av1_sequence_header(&obu.payload))
        .transpose()
}
/// Removes H.264/H.265 emulation-prevention bytes from NAL unit data.
///
/// Every `0x03` that directly follows two `0x00` bytes is dropped, which
/// converts the escaped form back to raw `00 00 xx` sequences. After a byte
/// is dropped, the zero-run count starts over.
pub fn remove_emulation_prevention(data: &[u8]) -> Vec<u8> {
    let mut rbsp = Vec::with_capacity(data.len());
    // Consecutive 0x00 bytes copied since the last dropped 0x03.
    let mut zero_run = 0usize;
    for &byte in data {
        if byte == 0x03 && zero_run >= 2 {
            zero_run = 0;
            continue;
        }
        rbsp.push(byte);
        zero_run = if byte == 0x00 { zero_run + 1 } else { 0 };
    }
    rbsp
}
#[cfg(test)]
mod tests {
// Unit tests for the Annex B / AVCC converters, SPS/PPS extraction,
// emulation-prevention removal and AV1 OBU splitting. Input buffers are
// hand-encoded byte sequences.
use super::*;
#[test]
fn test_split_annexb_single_nal_4byte_startcode() {
let data = [0x00, 0x00, 0x00, 0x01, 0x67, 0xAB, 0xCD];
let nals = split_annexb(&data);
assert_eq!(nals.len(), 1);
assert_eq!(nals[0], &[0x67, 0xAB, 0xCD]);
}
#[test]
fn test_split_annexb_multiple_nals() {
// A 4-byte start code followed later by a 3-byte start code.
let data = [
0x00, 0x00, 0x00, 0x01, 0x67, 0x11, 0x00, 0x00, 0x01, 0x68, 0x22, ];
let nals = split_annexb(&data);
assert_eq!(nals.len(), 2);
assert_eq!(nals[0], &[0x67, 0x11]);
assert_eq!(nals[1], &[0x68, 0x22]);
}
#[test]
fn test_annexb_to_avcc_roundtrip() {
let sps = [0x67u8, 0x42, 0x00, 0x1E];
let pps = [0x68u8, 0xCE, 0x38, 0x80];
let mut annexb = Vec::new();
annexb.extend_from_slice(&[0x00, 0x00, 0x00, 0x01]);
annexb.extend_from_slice(&sps);
annexb.extend_from_slice(&[0x00, 0x00, 0x00, 0x01]);
annexb.extend_from_slice(&pps);
// Annex B -> AVCC -> Annex B must reproduce the same NAL units.
let avcc = annexb_to_avcc(&annexb, LengthPrefixSize::Four).unwrap();
let back = avcc_to_annexb(&avcc, LengthPrefixSize::Four).unwrap();
let nals = split_annexb(&back);
assert_eq!(nals.len(), 2);
assert_eq!(nals[0], &sps);
assert_eq!(nals[1], &pps);
}
#[test]
fn test_avcc_to_annexb_two_byte_prefix() {
let nal = [0x65u8, 0x11, 0x22];
let mut avcc = Vec::new();
avcc.extend_from_slice(&(3u16).to_be_bytes());
avcc.extend_from_slice(&nal);
let result = avcc_to_annexb(&avcc, LengthPrefixSize::Two).unwrap();
// The 2-byte prefix is replaced by a 4-byte start code.
assert_eq!(&result[..4], &[0x00, 0x00, 0x00, 0x01]);
assert_eq!(&result[4..], &nal);
}
#[test]
fn test_avcc_invalid_length_prefix_error() {
// Prefix claims 100 bytes but only 2 follow.
let mut avcc = Vec::new();
avcc.extend_from_slice(&(100u32).to_be_bytes());
avcc.extend_from_slice(&[0xAA, 0xBB]);
let err = avcc_to_annexb(&avcc, LengthPrefixSize::Four).unwrap_err();
assert!(matches!(
err,
BitstreamFilterError::InvalidLengthPrefix { .. }
));
}
#[test]
fn test_extract_sps_pps() {
// SPS (0x67), PPS (0x68) and an IDR slice (0x65) that must be ignored.
let mut stream = Vec::new();
stream.extend_from_slice(&[0x00, 0x00, 0x00, 0x01, 0x67, 0x42, 0x00, 0x1E]);
stream.extend_from_slice(&[0x00, 0x00, 0x00, 0x01, 0x68, 0xCE]);
stream.extend_from_slice(&[0x00, 0x00, 0x00, 0x01, 0x65, 0x88]);
let (sps, pps) = extract_sps_pps(&stream);
assert_eq!(sps.len(), 1);
assert_eq!(pps.len(), 1);
assert_eq!(sps[0][0], 0x67);
assert_eq!(pps[0][0], 0x68);
}
#[test]
fn test_remove_emulation_prevention() {
// The 0x03 in `00 00 03` is dropped.
let input = [0x00u8, 0x00, 0x03, 0x01, 0xFF];
let output = remove_emulation_prevention(&input);
assert_eq!(output, [0x00, 0x00, 0x01, 0xFF]);
}
#[test]
fn test_split_av1_obus_sequence_header() {
// 0x0A: obu_type 1 (sequence header) with obu_has_size_field set;
// 0x04: leb128 obu_size of 4.
let payload = [0x00u8; 4]; let mut data = Vec::new();
data.push(0x0A); data.push(0x04);
data.extend_from_slice(&payload);
let obus = split_av1_obus(&data).unwrap();
assert_eq!(obus.len(), 1);
assert_eq!(obus[0].obu_type, Av1ObuType::SequenceHeader);
assert_eq!(obus[0].payload, payload);
}
#[test]
fn test_split_av1_obus_empty_error() {
let err = split_av1_obus(&[]).unwrap_err();
assert_eq!(err, BitstreamFilterError::EmptyPacket);
}
#[test]
fn test_split_av1_obus_multiple() {
// 0x12 0x00: temporal delimiter (type 2) with size 0;
// 0x22 0x02 0xAA 0xBB: tile group (type 4) with a 2-byte payload.
let mut data = Vec::new();
data.push(0x12); data.push(0x00); data.push(0x22); data.push(0x02); data.push(0xAA);
data.push(0xBB);
let obus = split_av1_obus(&data).unwrap();
assert_eq!(obus.len(), 2);
assert_eq!(obus[0].obu_type, Av1ObuType::TemporalDelimiter);
assert_eq!(obus[1].obu_type, Av1ObuType::TileGroup);
assert_eq!(obus[1].payload, [0xAA, 0xBB]);
}
#[test]
fn test_leb128_multi_byte() {
// 0xAC 0x02 = 0x2C + (0x02 << 7) = 300.
let data = [0xACu8, 0x02];
let (val, consumed) = read_leb128(&data, 0).unwrap();
assert_eq!(val, 300);
assert_eq!(consumed, 2);
}
#[test]
fn test_empty_packet_error() {
assert_eq!(
annexb_to_avcc(&[], LengthPrefixSize::Four).unwrap_err(),
BitstreamFilterError::EmptyPacket
);
assert_eq!(
avcc_to_annexb(&[], LengthPrefixSize::Four).unwrap_err(),
BitstreamFilterError::EmptyPacket
);
}
#[test]
fn test_length_prefix_size_from_raw() {
assert_eq!(
LengthPrefixSize::from_raw(1).unwrap(),
LengthPrefixSize::One
);
assert_eq!(
LengthPrefixSize::from_raw(2).unwrap(),
LengthPrefixSize::Two
);
assert_eq!(
LengthPrefixSize::from_raw(4).unwrap(),
LengthPrefixSize::Four
);
// 3 is not a valid AVCC length-prefix width.
assert!(LengthPrefixSize::from_raw(3).is_err());
}
}