use anyhow::{self};
use bytes::{Buf, Bytes};
/// The four-byte Annex B start code, `00 00 00 01`, as a static `Bytes` value.
pub const START_CODE: Bytes = Bytes::from_static(&[0, 0, 0, 1]);
/// Splits an Annex B byte stream held in a borrowed buffer into NAL units.
///
/// The buffer is consumed as units are produced: each unit, together with
/// the start code in front of it, is removed from the front of `buf`.
pub struct NalIterator<'a, T>
where
    T: Buf + AsRef<[u8]> + 'a,
{
    /// The borrowed input buffer that NAL units are carved out of.
    buf: &'a mut T,
    /// Length (3 or 4) of the start code currently at the front of `buf`,
    /// or `None` while no start code has been parsed yet.
    start: Option<usize>,
}
impl<'a, T: Buf + AsRef<[u8]> + 'a> NalIterator<'a, T> {
    /// Creates an iterator over `buf` that has not yet parsed any start code.
    pub fn new(buf: &'a mut T) -> Self {
        Self { buf, start: None }
    }

    /// Consumes the iterator and returns everything left in the buffer,
    /// after the leading start code, as one final NAL unit.
    ///
    /// Returns `Ok(None)` when no start code has been parsed yet and the
    /// buffer is too short to hold one.
    ///
    /// # Errors
    ///
    /// Fails when no start code has been parsed yet and the buffer does not
    /// begin with a valid Annex B start code.
    pub fn flush(self) -> anyhow::Result<Option<Bytes>> {
        // Reuse the start-code length remembered by `next`, otherwise parse
        // it from the front of the buffer now.
        let skip = if let Some(known) = self.start {
            known
        } else {
            match after_start_code(self.buf.as_ref())? {
                Some(len) => len,
                None => return Ok(None),
            }
        };

        // Drop the start code, then hand back all remaining bytes.
        self.buf.advance(skip);
        let rest = self.buf.remaining();
        Ok(Some(self.buf.copy_to_bytes(rest)))
    }
}
impl<'a, T: Buf + AsRef<[u8]> + 'a> Iterator for NalIterator<'a, T> {
    type Item = anyhow::Result<Bytes>;

    /// Yields the next NAL unit that is terminated by a following start code.
    ///
    /// Returns `None` when the buffer is too short to hold a start code or
    /// when no later start code is found; in both cases the buffer is left
    /// untouched. Returns `Some(Err(_))` when the buffer does not begin with
    /// a valid start code.
    fn next(&mut self) -> Option<Self::Item> {
        // Length of the start code at the front of the buffer: either
        // remembered from the previous call or parsed now.
        let skip = if let Some(known) = self.start {
            known
        } else {
            match after_start_code(self.buf.as_ref()) {
                Ok(Some(len)) => len,
                Ok(None) => return None,
                Err(err) => return Some(Err(err)),
            }
        };

        // Locate the start code that ends this unit; bail out without
        // consuming anything if it is not in the buffer yet.
        let (nal_len, next_code_len) = find_start_code(&self.buf.as_ref()[skip..])?;

        // Drop the leading start code and split off the unit's payload,
        // leaving the next start code at the front of the buffer.
        self.buf.advance(skip);
        let unit = self.buf.copy_to_bytes(nal_len);
        self.start = Some(next_code_len);
        Some(Ok(unit))
    }
}
/// Checks that `b` begins with an Annex B start code and returns its length.
///
/// Returns `Ok(Some(3))` for `00 00 01`, `Ok(Some(4))` for `00 00 00 01`,
/// and `Ok(None)` when `b` is too short to decide (fewer than three bytes,
/// or exactly `00 00 00`).
///
/// # Errors
///
/// Fails when the leading bytes cannot be the beginning of a start code.
pub fn after_start_code(b: &[u8]) -> anyhow::Result<Option<usize>> {
    match b {
        // Not enough bytes for even the short start code.
        [] | [_] | [_, _] => Ok(None),
        [0, 0, 1, ..] => Ok(Some(3)),
        // Could still become the long start code once more data arrives.
        [0, 0, 0] => Ok(None),
        [0, 0, 0, 1, ..] => Ok(Some(4)),
        // Three zeros followed by something other than 1.
        [0, 0, 0, ..] => anyhow::bail!("missing Annex B start code"),
        // Two zeros followed by a byte that is neither 0 nor 1.
        [0, 0, ..] => anyhow::bail!("invalid Annex B start code"),
        // First or second byte is non-zero.
        _ => anyhow::bail!("missing Annex B start code"),
    }
}
/// Searches `b` for the first Annex B start code.
///
/// On success returns `(offset, len)`: `offset` is the index in `b` where
/// the start code begins and `len` is its length, 4 when the `00 00 01`
/// sequence is directly preceded by a zero byte and 3 otherwise. Returns
/// `None` when `b` contains no complete start code.
///
/// The scan inspects the third (and, when needed, fourth) byte of the
/// current window so it can skip several bytes at once when they cannot be
/// part of a start code.
pub fn find_start_code(b: &[u8]) -> Option<(usize, usize)> {
    let mut pos = 0usize;

    while let [first, second, third, tail @ ..] = &b[pos..] {
        let step = match (*third, tail.first().copied()) {
            // Exactly three bytes left and the last one is zero: no room
            // for the terminating 1.
            (0, None) => return None,
            // `?? 00 00 01`: a start code begins at the second byte, or at
            // the first one when that is zero as well.
            (0, Some(1)) if *second == 0 => {
                return Some(if *first == 0 { (pos, 4) } else { (pos + 1, 3) });
            }
            // Two zeros in a row: a start code may begin at the very next
            // byte, so move forward by one only.
            (0, Some(0)) => 1,
            // The fourth byte rules out any start code within these four.
            (0, Some(_)) => 4,
            // `00 00 01` right at the cursor.
            (1, _) if *first == 0 && *second == 0 => return Some((pos, 3)),
            // The third byte rules out any start code within these three.
            _ => 3,
        };
        pos += step;
    }

    None
}
// Unit tests for `after_start_code`, `find_start_code`, and `NalIterator`
// (both iteration via `next` and the final `flush`).
#[cfg(test)]
mod tests {
use super::*;
// `00 00 01` at the front is reported as a 3-byte start code.
#[test]
fn test_after_start_code_3_byte() {
let buf = &[0, 0, 1, 0x67];
assert_eq!(after_start_code(buf).unwrap(), Some(3));
}
// `00 00 00 01` at the front is reported as a 4-byte start code.
#[test]
fn test_after_start_code_4_byte() {
let buf = &[0, 0, 0, 1, 0x67];
assert_eq!(after_start_code(buf).unwrap(), Some(4));
}
// Fewer than three bytes is not an error; the result is `None`.
#[test]
fn test_after_start_code_too_short() {
let buf = &[0, 0];
assert_eq!(after_start_code(buf).unwrap(), None);
}
// Exactly three zero bytes could still become a 4-byte start code, so the
// result is `None` rather than an error.
#[test]
fn test_after_start_code_incomplete_4_byte() {
let buf = &[0, 0, 0];
assert_eq!(after_start_code(buf).unwrap(), None);
}
// A non-zero first byte is rejected.
#[test]
fn test_after_start_code_invalid_first_byte() {
let buf = &[1, 0, 1];
assert!(after_start_code(buf).is_err());
}
// A non-zero second byte is rejected.
#[test]
fn test_after_start_code_invalid_second_byte() {
let buf = &[0, 1, 1];
assert!(after_start_code(buf).is_err());
}
// A third byte that is neither 0 nor 1 is rejected.
#[test]
fn test_after_start_code_invalid_third_byte() {
let buf = &[0, 0, 2];
assert!(after_start_code(buf).is_err());
}
// Three zeros followed by a byte other than 1 is rejected.
#[test]
fn test_after_start_code_invalid_4_byte_pattern() {
let buf = &[0, 0, 0, 2];
assert!(after_start_code(buf).is_err());
}
// A 3-byte start code after four payload bytes is found at offset 4.
#[test]
fn test_find_start_code_3_byte() {
let buf = &[0x67, 0x42, 0x00, 0x1f, 0, 0, 1];
assert_eq!(find_start_code(buf), Some((4, 3)));
}
// A 4-byte start code at the very beginning is found at offset 0.
#[test]
fn test_find_start_code_4_byte() {
let buf = &[0, 0, 0, 1, 0x67];
assert_eq!(find_start_code(buf), Some((0, 4)));
}
// A 4-byte start code after four payload bytes is found at offset 4.
#[test]
fn test_find_start_code_4_byte_after_data() {
let buf = &[0x67, 0x42, 0xff, 0x1f, 0, 0, 0, 1];
assert_eq!(find_start_code(buf), Some((4, 4)));
}
// A 3-byte start code at the very beginning is found at offset 0.
#[test]
fn test_find_start_code_at_start_3_byte() {
let buf = &[0, 0, 1, 0x67];
assert_eq!(find_start_code(buf), Some((0, 3)));
}
// Data without any start code yields `None`.
#[test]
fn test_find_start_code_none() {
let buf = &[0x67, 0x42, 0x00, 0x1f, 0xff];
assert_eq!(find_start_code(buf), None);
}
// Trailing zeros with no terminating 1 are not a start code.
#[test]
fn test_find_start_code_trailing_zeros() {
let buf = &[0x67, 0x42, 0x00, 0x1f, 0, 0];
assert_eq!(find_start_code(buf), None);
}
// A non-zero byte directly before `00 00 01` gives a 3-byte code at offset 1.
#[test]
fn test_find_start_code_edge_case_3_byte() {
let buf = &[0xff, 0, 0, 1];
assert_eq!(find_start_code(buf), Some((1, 3)));
}
// `00 00 ff` must not be mistaken for a start code; the real one is at offset 4.
#[test]
fn test_find_start_code_false_positive_avoidance() {
let buf = &[0xff, 0, 0, 0xff, 0, 0, 1];
assert_eq!(find_start_code(buf), Some((4, 3)));
}
// A non-zero byte directly before `00 00 00 01` gives a 4-byte code at offset 1.
#[test]
fn test_find_start_code_4_byte_after_nonzero() {
let buf = &[0xff, 0, 0, 0, 1];
assert_eq!(find_start_code(buf), Some((1, 4)));
}
// A long run of zeros before the 1: only checks that some start code of
// length 3 or 4 is reported inside the buffer, not its exact position.
#[test]
fn test_find_start_code_consecutive_zeros() {
let buf = &[0xff, 0, 0, 0, 0, 0, 1];
let result = find_start_code(buf);
assert!(result.is_some());
let (pos, size) = result.unwrap();
assert!(size == 3 || size == 4);
assert!(pos < buf.len());
}
// One NAL between two 3-byte start codes; the trailing start code stays in
// the buffer after iteration stops.
#[test]
fn test_nal_iterator_simple_3_byte() {
let mut data = Bytes::from(vec![0, 0, 1, 0x67, 0x42, 0, 0, 1]);
let mut iter = NalIterator::new(&mut data);
let nal = iter.next().unwrap().unwrap();
assert_eq!(nal.as_ref(), &[0x67, 0x42]);
assert!(iter.next().is_none());
assert_eq!(data.as_ref(), &[0, 0, 1]);
}
// Same as above with 4-byte start codes.
#[test]
fn test_nal_iterator_simple_4_byte() {
let mut data = Bytes::from(vec![0, 0, 0, 1, 0x67, 0x42, 0, 0, 0, 1]);
let mut iter = NalIterator::new(&mut data);
let nal = iter.next().unwrap().unwrap();
assert_eq!(nal.as_ref(), &[0x67, 0x42]);
assert!(iter.next().is_none());
assert_eq!(data.as_ref(), &[0, 0, 0, 1]);
}
// Two NALs in a row are yielded in order; the last start code remains.
#[test]
fn test_nal_iterator_multiple_nals() {
let mut data = Bytes::from(vec![0, 0, 0, 1, 0x67, 0x42, 0, 0, 0, 1, 0x68, 0xce, 0, 0, 0, 1]);
let mut iter = NalIterator::new(&mut data);
let nal1 = iter.next().unwrap().unwrap();
assert_eq!(nal1.as_ref(), &[0x67, 0x42]);
let nal2 = iter.next().unwrap().unwrap();
assert_eq!(nal2.as_ref(), &[0x68, 0xce]);
assert!(iter.next().is_none());
assert_eq!(data.as_ref(), &[0, 0, 0, 1]);
}
// Three units named SPS, PPS and IDR by the test; for each one the low five
// bits of the first byte are checked (7, 8 and 5) along with the full payload.
#[test]
fn test_nal_iterator_realistic_h264() {
let mut data = Bytes::from(vec![
0, 0, 0, 1, 0x67, 0x42, 0x00, 0x1f, 0, 0, 0, 1, 0x68, 0xce, 0x3c, 0x80, 0, 0, 0, 1, 0x65, 0x88, 0x84, 0x00, 0, 0, 0, 1,
]);
let mut iter = NalIterator::new(&mut data);
let sps = iter.next().unwrap().unwrap();
assert_eq!(sps[0] & 0x1f, 7); assert_eq!(sps.as_ref(), &[0x67, 0x42, 0x00, 0x1f]);
let pps = iter.next().unwrap().unwrap();
assert_eq!(pps[0] & 0x1f, 8); assert_eq!(pps.as_ref(), &[0x68, 0xce, 0x3c, 0x80]);
let idr = iter.next().unwrap().unwrap();
assert_eq!(idr[0] & 0x1f, 5); assert_eq!(idr.as_ref(), &[0x65, 0x88, 0x84, 0x00]);
assert!(iter.next().is_none());
assert_eq!(data.as_ref(), &[0, 0, 0, 1]);
}
// Four units named VPS, SPS, PPS and IDR by the test; for each one bits 1..=6
// of the first byte are checked (32, 33, 34 and 19) along with the full payload.
#[test]
fn test_nal_iterator_realistic_h265() {
let mut data = Bytes::from(vec![
0, 0, 0, 1, 0x40, 0x01, 0x0c, 0x01, 0, 0, 0, 1, 0x42, 0x01, 0x01, 0x60, 0, 0, 0, 1, 0x44, 0x01, 0xc0, 0xf1, 0, 0, 0, 1, 0x26, 0x01, 0x9a, 0x20, 0, 0, 0, 1,
]);
let mut iter = NalIterator::new(&mut data);
let vps = iter.next().unwrap().unwrap();
assert_eq!((vps[0] >> 1) & 0x3f, 32); assert_eq!(vps.as_ref(), &[0x40, 0x01, 0x0c, 0x01]);
let sps = iter.next().unwrap().unwrap();
assert_eq!((sps[0] >> 1) & 0x3f, 33); assert_eq!(sps.as_ref(), &[0x42, 0x01, 0x01, 0x60]);
let pps = iter.next().unwrap().unwrap();
assert_eq!((pps[0] >> 1) & 0x3f, 34); assert_eq!(pps.as_ref(), &[0x44, 0x01, 0xc0, 0xf1]);
let idr = iter.next().unwrap().unwrap();
assert_eq!((idr[0] >> 1) & 0x3f, 19); assert_eq!(idr.as_ref(), &[0x26, 0x01, 0x9a, 0x20]);
assert!(iter.next().is_none());
assert_eq!(data.as_ref(), &[0, 0, 0, 1]);
}
// A buffer that does not begin with a start code yields an error and is
// left unmodified.
#[test]
fn test_nal_iterator_invalid_start() {
let mut data = Bytes::from(vec![1, 0, 1, 0x67]);
let mut iter = NalIterator::new(&mut data);
assert!(iter.next().unwrap().is_err());
assert_eq!(data.as_ref(), &[1, 0, 1, 0x67]);
}
// Two back-to-back start codes produce a zero-length NAL between them.
#[test]
fn test_nal_iterator_empty_nal() {
let mut data = Bytes::from(vec![0, 0, 1, 0, 0, 1, 0x67, 0, 0, 1]);
let mut iter = NalIterator::new(&mut data);
let nal1 = iter.next().unwrap().unwrap();
assert_eq!(nal1.len(), 0);
let nal2 = iter.next().unwrap().unwrap();
assert_eq!(nal2.as_ref(), &[0x67]);
assert!(iter.next().is_none());
assert_eq!(data.as_ref(), &[0, 0, 1]);
}
// Zeros inside a payload that are not followed by 1 stay part of the NAL.
#[test]
fn test_nal_iterator_nal_with_embedded_zeros() {
let mut data = Bytes::from(vec![
0, 0, 1, 0x67, 0x00, 0x00, 0x00, 0xff, 0, 0, 1, 0x68, 0, 0, 1,
]);
let mut iter = NalIterator::new(&mut data);
let nal1 = iter.next().unwrap().unwrap();
assert_eq!(nal1.as_ref(), &[0x67, 0x00, 0x00, 0x00, 0xff]);
let nal2 = iter.next().unwrap().unwrap();
assert_eq!(nal2.as_ref(), &[0x68]);
assert!(iter.next().is_none());
assert_eq!(data.as_ref(), &[0, 0, 1]);
}
// After `next` runs dry, `flush` returns the unterminated final NAL.
#[test]
fn test_flush_after_iteration() {
let mut data = Bytes::from(vec![
0, 0, 0, 1, 0x67, 0x42, 0, 0, 0, 1, 0x68, 0xce, 0x3c, 0x80, ]);
let mut iter = NalIterator::new(&mut data);
let nal1 = iter.next().unwrap().unwrap();
assert_eq!(nal1.as_ref(), &[0x67, 0x42]);
assert!(iter.next().is_none());
let final_nal = iter.flush().unwrap().unwrap();
assert_eq!(final_nal.as_ref(), &[0x68, 0xce, 0x3c, 0x80]);
}
// `flush` without any prior `next` parses the 3-byte start code itself.
#[test]
fn test_flush_single_nal() {
let mut data = Bytes::from(vec![0, 0, 1, 0x67, 0x42, 0x00, 0x1f]);
let iter = NalIterator::new(&mut data);
let final_nal = iter.flush().unwrap().unwrap();
assert_eq!(final_nal.as_ref(), &[0x67, 0x42, 0x00, 0x1f]);
}
// `flush` without any prior `next` parses the 4-byte start code itself.
#[test]
fn test_flush_4_byte_start_code() {
let mut data = Bytes::from(vec![0, 0, 0, 1, 0x65, 0x88, 0x84, 0x00, 0xff]);
let iter = NalIterator::new(&mut data);
let final_nal = iter.flush().unwrap().unwrap();
assert_eq!(final_nal.as_ref(), &[0x65, 0x88, 0x84, 0x00, 0xff]);
}
// `flush` on data that does not begin with a start code is an error.
#[test]
fn test_flush_no_start_code() {
let mut data = Bytes::from(vec![0x67, 0x42, 0x00, 0x1f]);
let iter = NalIterator::new(&mut data);
let result = iter.flush();
assert!(result.is_err());
}
// `flush` on an empty buffer returns `Ok(None)`.
#[test]
fn test_flush_empty_buffer() {
let mut data = Bytes::from(vec![]);
let iter = NalIterator::new(&mut data);
let result = iter.flush().unwrap();
assert!(result.is_none());
}
// `flush` on a buffer too short to hold a start code returns `Ok(None)`.
#[test]
fn test_flush_incomplete_start_code() {
let mut data = Bytes::from(vec![0, 0]);
let iter = NalIterator::new(&mut data);
let result = iter.flush().unwrap();
assert!(result.is_none());
}
// Two terminated NALs come from `next`; the unterminated third from `flush`.
#[test]
fn test_flush_multiple_nals_then_flush() {
let mut data = Bytes::from(vec![
0, 0, 0, 1, 0x67, 0x42, 0, 0, 0, 1, 0x68, 0xce, 0, 0, 0, 1, 0x65, 0x88, 0x84, ]);
let mut iter = NalIterator::new(&mut data);
let sps = iter.next().unwrap().unwrap();
assert_eq!(sps.as_ref(), &[0x67, 0x42]);
let pps = iter.next().unwrap().unwrap();
assert_eq!(pps.as_ref(), &[0x68, 0xce]);
assert!(iter.next().is_none());
let idr = iter.flush().unwrap().unwrap();
assert_eq!(idr.as_ref(), &[0x65, 0x88, 0x84]);
}
// When only a start code is left, `flush` returns an empty NAL, not `None`.
#[test]
fn test_flush_empty_final_nal() {
let mut data = Bytes::from(vec![
0, 0, 0, 1, 0x67, 0x42, 0, 0, 0, 1, ]);
let mut iter = NalIterator::new(&mut data);
let nal1 = iter.next().unwrap().unwrap();
assert_eq!(nal1.as_ref(), &[0x67, 0x42]);
assert!(iter.next().is_none());
let final_nal = iter.flush().unwrap().unwrap();
assert_eq!(final_nal.len(), 0);
}
}