use std::convert::TryFrom;
use std::fmt;
7
/// Errors produced by the bitstream filtering and parsing helpers in this module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BitstreamFilterError {
    /// A fixed-size field ran past the end of the input.
    BufferTooShort {
        /// End offset (from the start of the input) the field required.
        needed: usize,
        /// Total number of bytes actually present.
        available: usize,
    },
    /// A length field claims more bytes than are available.
    InvalidLengthPrefix {
        /// Byte offset associated with the offending length field.
        offset: usize,
        /// Length the field claimed.
        claimed: usize,
        /// Number of bytes that were actually available.
        available: usize,
    },
    /// An AV1 OBU header (or its leb128 size field) could not be decoded.
    MalformedObuHeader {
        /// Byte offset at which decoding failed.
        offset: usize,
    },
    /// An AV1 sequence header could not be parsed (e.g. it is truncated).
    MalformedSequenceHeader,
    /// The caller supplied an empty packet.
    EmptyPacket,
    /// Carries an H.264 `nal_unit_type` value that could not be handled.
    UnknownNalType(u8),
    /// A length-prefix width other than 1, 2 or 4 bytes.
    InvalidLengthPrefixSize(u8),
}

impl fmt::Display for BitstreamFilterError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            Self::BufferTooShort { needed, available } => write!(
                f,
                "buffer too short: needed {needed}, available {available}"
            ),
            Self::InvalidLengthPrefix {
                offset,
                claimed,
                available,
            } => write!(
                f,
                "invalid length prefix at offset {offset}: claims {claimed} bytes but only {available} remain"
            ),
            Self::MalformedObuHeader { offset } => {
                write!(f, "malformed OBU header at offset {offset}")
            }
            // Fixed messages need no formatting machinery.
            Self::MalformedSequenceHeader => f.write_str("malformed AV1 sequence header"),
            Self::EmptyPacket => f.write_str("packet is empty"),
            Self::UnknownNalType(nal_type) => write!(f, "unknown NAL unit type: {nal_type}"),
            Self::InvalidLengthPrefixSize(size) => {
                write!(f, "invalid length prefix size: {size} (must be 1, 2, or 4)")
            }
        }
    }
}

impl std::error::Error for BitstreamFilterError {}

/// Convenience alias for results whose error type is [`BitstreamFilterError`].
pub type BitstreamResult<T> = Result<T, BitstreamFilterError>;
78
79#[derive(Debug, Clone, Copy, PartialEq, Eq)]
85pub enum LengthPrefixSize {
86 One = 1,
88 Two = 2,
90 Four = 4,
92}
93
94impl LengthPrefixSize {
95 pub fn from_raw(raw: u8) -> BitstreamResult<Self> {
97 match raw {
98 1 => Ok(Self::One),
99 2 => Ok(Self::Two),
100 4 => Ok(Self::Four),
101 other => Err(BitstreamFilterError::InvalidLengthPrefixSize(other)),
102 }
103 }
104
105 pub fn as_usize(self) -> usize {
107 self as usize
108 }
109}
110
/// Three-byte Annex B start code prefix (`00 00 01`).
const START_CODE_3: [u8; 3] = [0x00, 0x00, 0x01];
/// Four-byte Annex B start code (`00 00 00 01`).
const START_CODE_4: [u8; 4] = [0x00, 0x00, 0x00, 0x01];

/// Splits an Annex B byte stream into NAL units, with start codes removed.
///
/// Every `00 00 01` sequence is treated as a start code. The four-byte form
/// `00 00 00 01` needs no separate search, because its leading zero ends up
/// at the tail of the preceding segment and is trimmed there.
///
/// Trailing `0x00` bytes are stripped from every segment. H.264 §7.4.1
/// forbids a NAL unit from ending in `0x00`, so such bytes can only be
/// `leading_zero_8bits` / `trailing_zero_8bits` padding or the first byte of
/// a four-byte start code. Segments that are empty after trimming are
/// skipped, so padding never surfaces as a bogus `[0x00]` NAL.
///
/// Bytes before the first start code (when the stream does not begin with
/// one) are returned as the first unit. The returned slices borrow from `data`.
pub fn split_annexb(data: &[u8]) -> Vec<&[u8]> {
    let mut nals: Vec<&[u8]> = Vec::new();
    let code_len = START_CODE_3.len();
    let mut start = 0usize;
    let mut i = 0usize;

    while i + code_len <= data.len() {
        if data[i..i + code_len] == START_CODE_3 {
            push_nal(&mut nals, &data[start..i]);
            i += code_len;
            start = i;
        } else {
            i += 1;
        }
    }
    push_nal(&mut nals, &data[start..]);

    nals
}

/// Appends `segment` to `nals` with trailing zero bytes removed, unless
/// nothing is left after trimming.
fn push_nal<'a>(nals: &mut Vec<&'a [u8]>, segment: &'a [u8]) {
    let end = segment
        .iter()
        .rposition(|&b| b != 0)
        .map_or(0, |last| last + 1);
    if end > 0 {
        nals.push(&segment[..end]);
    }
}
166
167pub fn annexb_to_avcc(data: &[u8], prefix_size: LengthPrefixSize) -> BitstreamResult<Vec<u8>> {
172 if data.is_empty() {
173 return Err(BitstreamFilterError::EmptyPacket);
174 }
175 let nals = split_annexb(data);
176 let prefix_bytes = prefix_size.as_usize();
177 let total: usize = nals.iter().map(|n| prefix_bytes + n.len()).sum();
178 let mut out = Vec::with_capacity(total);
179
180 for nal in nals {
181 let nal_len = nal.len();
182 match prefix_size {
183 LengthPrefixSize::One => {
184 out.push(nal_len as u8);
185 }
186 LengthPrefixSize::Two => {
187 out.extend_from_slice(&(nal_len as u16).to_be_bytes());
188 }
189 LengthPrefixSize::Four => {
190 out.extend_from_slice(&(nal_len as u32).to_be_bytes());
191 }
192 }
193 out.extend_from_slice(nal);
194 }
195
196 Ok(out)
197}
198
199pub fn avcc_to_annexb(data: &[u8], prefix_size: LengthPrefixSize) -> BitstreamResult<Vec<u8>> {
204 if data.is_empty() {
205 return Err(BitstreamFilterError::EmptyPacket);
206 }
207 let prefix_bytes = prefix_size.as_usize();
208 let mut out = Vec::with_capacity(data.len() + data.len() / 4);
209 let mut offset = 0usize;
210
211 while offset < data.len() {
212 if offset + prefix_bytes > data.len() {
213 return Err(BitstreamFilterError::BufferTooShort {
214 needed: offset + prefix_bytes,
215 available: data.len(),
216 });
217 }
218 let nal_len = read_be_uint(&data[offset..offset + prefix_bytes], prefix_bytes);
219 offset += prefix_bytes;
220
221 let remaining = data.len() - offset;
222 if nal_len > remaining {
223 return Err(BitstreamFilterError::InvalidLengthPrefix {
224 offset: offset - prefix_bytes,
225 claimed: nal_len,
226 available: remaining,
227 });
228 }
229 out.extend_from_slice(&START_CODE_4);
230 out.extend_from_slice(&data[offset..offset + nal_len]);
231 offset += nal_len;
232 }
233
234 Ok(out)
235}
236
/// Decodes the first `n` bytes of `bytes` as a big-endian unsigned integer.
///
/// Only widths of 1, 2 and 4 are supported; any other `n` yields 0 without
/// touching `bytes`. Panics if `bytes` holds fewer than `n` bytes.
fn read_be_uint(bytes: &[u8], n: usize) -> usize {
    if !matches!(n, 1 | 2 | 4) {
        return 0;
    }
    bytes[..n]
        .iter()
        .fold(0usize, |acc, &byte| (acc << 8) | usize::from(byte))
}
246
/// H.264 NAL unit types recognised by this module, decoded from the low five
/// bits (`nal_unit_type`) of a NAL header byte.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum H264NalType {
    /// Type 1: coded slice of a non-IDR picture.
    NonIdrSlice,
    /// Type 5: coded slice of an IDR picture.
    IdrSlice,
    /// Type 6: supplemental enhancement information.
    Sei,
    /// Type 7: sequence parameter set.
    Sps,
    /// Type 8: picture parameter set.
    Pps,
    /// Type 9: access unit delimiter.
    Aud,
    /// Type 10: end of sequence.
    EndOfSeq,
    /// Type 11: end of stream.
    EndOfStream,
    /// Type 12: filler data.
    FillerData,
    /// Any other `nal_unit_type` value (0-31), carried verbatim.
    Other(u8),
}

impl H264NalType {
    /// Classifies a NAL unit from its first (header) byte.
    ///
    /// The upper three bits of the header byte are ignored.
    pub fn from_nal_byte(byte: u8) -> Self {
        const NAL_UNIT_TYPE_MASK: u8 = 0b0001_1111;
        let nal_unit_type = byte & NAL_UNIT_TYPE_MASK;
        match nal_unit_type {
            1 => Self::NonIdrSlice,
            5 => Self::IdrSlice,
            6 => Self::Sei,
            7 => Self::Sps,
            8 => Self::Pps,
            9 => Self::Aud,
            10 => Self::EndOfSeq,
            11 => Self::EndOfStream,
            12 => Self::FillerData,
            other => Self::Other(other),
        }
    }
}
293
294#[derive(Debug, Clone)]
296pub struct NalUnit<'a> {
297 pub nal_type: H264NalType,
299 pub data: &'a [u8],
301}
302
303impl<'a> NalUnit<'a> {
304 pub fn from_raw(data: &'a [u8]) -> Option<Self> {
306 let first = *data.first()?;
307 Some(Self {
308 nal_type: H264NalType::from_nal_byte(first),
309 data,
310 })
311 }
312}
313
314pub fn extract_sps(data: &[u8]) -> Vec<NalUnit<'_>> {
316 split_annexb(data)
317 .into_iter()
318 .filter_map(NalUnit::from_raw)
319 .filter(|n| n.nal_type == H264NalType::Sps)
320 .collect()
321}
322
323pub fn extract_pps(data: &[u8]) -> Vec<NalUnit<'_>> {
325 split_annexb(data)
326 .into_iter()
327 .filter_map(NalUnit::from_raw)
328 .filter(|n| n.nal_type == H264NalType::Pps)
329 .collect()
330}
331
332pub fn extract_sps_pps(data: &[u8]) -> (Vec<Vec<u8>>, Vec<Vec<u8>>) {
334 let nals = split_annexb(data);
335 let mut sps_list = Vec::new();
336 let mut pps_list = Vec::new();
337 for nal_bytes in nals {
338 if let Some(nal) = NalUnit::from_raw(nal_bytes) {
339 match nal.nal_type {
340 H264NalType::Sps => sps_list.push(nal.data.to_vec()),
341 H264NalType::Pps => pps_list.push(nal.data.to_vec()),
342 _ => {}
343 }
344 }
345 }
346 (sps_list, pps_list)
347}
348
/// AV1 OBU types, as carried in the `obu_type` field of an OBU header.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Av1ObuType {
    /// `obu_type` 1.
    SequenceHeader,
    /// `obu_type` 2.
    TemporalDelimiter,
    /// `obu_type` 3.
    FrameHeader,
    /// `obu_type` 4.
    TileGroup,
    /// `obu_type` 5.
    Metadata,
    /// `obu_type` 6.
    Frame,
    /// `obu_type` 7.
    RedundantFrameHeader,
    /// `obu_type` 8.
    TileList,
    /// `obu_type` 15.
    Padding,
    /// Any other `obu_type` value, carried verbatim.
    Reserved(u8),
}

impl Av1ObuType {
    /// Maps a raw `obu_type` value to its variant.
    fn from_raw(raw: u8) -> Self {
        // Index = raw obu_type; `None` marks values without a named variant.
        const KNOWN: [Option<Av1ObuType>; 16] = [
            None,
            Some(Av1ObuType::SequenceHeader),
            Some(Av1ObuType::TemporalDelimiter),
            Some(Av1ObuType::FrameHeader),
            Some(Av1ObuType::TileGroup),
            Some(Av1ObuType::Metadata),
            Some(Av1ObuType::Frame),
            Some(Av1ObuType::RedundantFrameHeader),
            Some(Av1ObuType::TileList),
            None,
            None,
            None,
            None,
            None,
            None,
            Some(Av1ObuType::Padding),
        ];
        KNOWN
            .get(usize::from(raw))
            .copied()
            .flatten()
            .unwrap_or(Self::Reserved(raw))
    }
}

/// One OBU extracted from an AV1 packet.
#[derive(Debug, Clone)]
pub struct Av1Obu {
    /// Type decoded from the OBU header byte.
    pub obu_type: Av1ObuType,
    /// Bytes following the OBU header, the optional extension byte and the
    /// optional leb128 size field.
    pub payload: Vec<u8>,
}
403
404fn read_leb128(data: &[u8], offset: usize) -> BitstreamResult<(u64, usize)> {
408 let mut result: u64 = 0;
409 let mut shift = 0u32;
410 let mut consumed = 0usize;
411 loop {
412 if offset + consumed >= data.len() {
413 return Err(BitstreamFilterError::MalformedObuHeader { offset });
414 }
415 let byte = data[offset + consumed];
416 consumed += 1;
417 result |= ((byte & 0x7F) as u64) << shift;
418 if byte & 0x80 == 0 {
419 break;
420 }
421 shift += 7;
422 if shift >= 56 {
423 return Err(BitstreamFilterError::MalformedObuHeader { offset });
424 }
425 }
426 Ok((result, consumed))
427}
428
429pub fn split_av1_obus(data: &[u8]) -> BitstreamResult<Vec<Av1Obu>> {
434 if data.is_empty() {
435 return Err(BitstreamFilterError::EmptyPacket);
436 }
437 let mut obus = Vec::new();
438 let mut offset = 0usize;
439 let len = data.len();
440
441 while offset < len {
442 if offset >= len {
443 break;
444 }
445 let header_byte = data[offset];
446 let forbidden_bit = (header_byte >> 7) & 1;
447 if forbidden_bit != 0 {
448 return Err(BitstreamFilterError::MalformedObuHeader { offset });
449 }
450 let obu_type_raw = (header_byte >> 3) & 0x0F;
451 let obu_extension_flag = (header_byte >> 2) & 1;
452 let obu_has_size_field = (header_byte >> 1) & 1;
453 offset += 1;
454
455 if obu_extension_flag == 1 {
457 if offset >= len {
458 return Err(BitstreamFilterError::MalformedObuHeader { offset });
459 }
460 offset += 1;
461 }
462
463 let payload_len = if obu_has_size_field == 1 {
464 let (sz, consumed) = read_leb128(data, offset)?;
465 offset += consumed;
466 sz as usize
467 } else {
468 len - offset
470 };
471
472 if offset + payload_len > len {
473 return Err(BitstreamFilterError::InvalidLengthPrefix {
474 offset,
475 claimed: payload_len,
476 available: len - offset,
477 });
478 }
479
480 let payload = data[offset..offset + payload_len].to_vec();
481 offset += payload_len;
482
483 obus.push(Av1Obu {
484 obu_type: Av1ObuType::from_raw(obu_type_raw),
485 payload,
486 });
487 }
488
489 Ok(obus)
490}
491
/// A subset of the fields of an AV1 sequence header OBU.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Av1SequenceHeader {
    /// `seq_profile` (3 bits).
    pub seq_profile: u8,
    /// `still_picture` flag.
    pub still_picture: bool,
    /// `reduced_still_picture_header` flag.
    pub reduced_still_picture_header: bool,
    /// `max_frame_width_minus_1 + 1`.
    pub max_frame_width: u32,
    /// `max_frame_height_minus_1 + 1`.
    pub max_frame_height: u32,
    /// `high_bitdepth` flag from `color_config()`.
    pub high_bitdepth: bool,
    /// `twelve_bit` flag; `false` unless profile 2 with `high_bitdepth`.
    pub twelve_bit: bool,
    /// `mono_chrome` flag; always `false` for profile 1.
    pub mono_chrome: bool,
}
519
520struct BitReader<'a> {
522 data: &'a [u8],
523 byte_offset: usize,
524 bit_offset: u8,
525}
526
527impl<'a> BitReader<'a> {
528 fn new(data: &'a [u8]) -> Self {
529 Self {
530 data,
531 byte_offset: 0,
532 bit_offset: 0,
533 }
534 }
535
536 fn read_bit(&mut self) -> BitstreamResult<u8> {
537 if self.byte_offset >= self.data.len() {
538 return Err(BitstreamFilterError::MalformedSequenceHeader);
539 }
540 let byte = self.data[self.byte_offset];
541 let bit = (byte >> (7 - self.bit_offset)) & 1;
542 self.bit_offset += 1;
543 if self.bit_offset == 8 {
544 self.bit_offset = 0;
545 self.byte_offset += 1;
546 }
547 Ok(bit)
548 }
549
550 fn read_bits(&mut self, n: u8) -> BitstreamResult<u32> {
551 let mut val = 0u32;
552 for _ in 0..n {
553 val = (val << 1) | self.read_bit()? as u32;
554 }
555 Ok(val)
556 }
557
558 fn u(&mut self, n: u8) -> BitstreamResult<u32> {
561 self.read_bits(n)
562 }
563
564 fn f(&mut self, n: u8) -> BitstreamResult<u32> {
565 self.read_bits(n)
566 }
567}
568
569pub fn parse_av1_sequence_header(payload: &[u8]) -> BitstreamResult<Av1SequenceHeader> {
573 let mut r = BitReader::new(payload);
574
575 let seq_profile = r.f(3)? as u8;
576 let still_picture = r.f(1)? != 0;
577 let reduced_still_picture_header = r.f(1)? != 0;
578
579 let (timing_info_present, decoder_model_info_present) = if reduced_still_picture_header {
582 (false, false)
583 } else {
584 let tip = r.f(1)? != 0;
585 let dmip = if tip {
586 r.u(32)?;
589 r.u(32)?;
590 let epi = r.f(1)?;
591 if epi != 0 {
592 let _ = read_uvlc(&mut r)?;
594 }
595 r.f(1)? != 0
596 } else {
597 false
598 };
599 if dmip {
600 let _ = r.u(5)?;
603 let _ = r.u(32)?;
604 let _ = r.u(9)?;
605 }
606 (tip, dmip)
607 };
608
609 let _ = timing_info_present;
610 let _ = decoder_model_info_present;
611
612 if !reduced_still_picture_header {
614 let op_cnt = r.u(5)?; for _ in 0..=op_cnt {
616 let _op_idc = r.u(12)?;
617 let _seq_level_idx = r.u(5)?;
618 let seq_tier = if r.u(5)? > 7 { r.u(1)? } else { 0 };
619 let _ = seq_tier;
620 if decoder_model_info_present {
621 let _decoder_model_present = r.u(1)?;
622 }
624 if !reduced_still_picture_header {
625 let _initial_display_delay_present = r.u(1)?;
626 if decoder_model_info_present {
627 let _initial_display_delay_minus_1 = r.u(4)?;
628 }
629 }
630 }
631 }
632
633 let fw_bits = r.u(4)? + 1;
635 let fh_bits = r.u(4)? + 1;
636 let max_frame_width = r.u(fw_bits as u8)? + 1;
637 let max_frame_height = r.u(fh_bits as u8)? + 1;
638
639 if !reduced_still_picture_header {
641 let frame_id_numbers_present = r.u(1)?;
642 if frame_id_numbers_present != 0 {
643 let _delta_frame_id_length = r.u(4)?;
644 let _additional_frame_id_length = r.u(3)?;
645 }
646 }
647
648 let _use_128 = r.u(1)?;
650 let _enable_filter_intra = r.u(1)?;
651 let _enable_intra_edge_filter = r.u(1)?;
652
653 let high_bitdepth = r.u(1)? != 0;
656 let twelve_bit = if seq_profile == 2 && high_bitdepth {
657 r.u(1)? != 0
658 } else {
659 false
660 };
661 let mono_chrome = if seq_profile == 1 {
662 false
663 } else {
664 r.u(1)? != 0
665 };
666
667 Ok(Av1SequenceHeader {
668 seq_profile,
669 still_picture,
670 reduced_still_picture_header,
671 max_frame_width,
672 max_frame_height,
673 high_bitdepth,
674 twelve_bit,
675 mono_chrome,
676 })
677}
678
679fn read_uvlc(r: &mut BitReader<'_>) -> BitstreamResult<u32> {
681 let mut leading_zeros = 0u32;
682 loop {
683 let bit = r.read_bit()?;
684 if bit != 0 {
685 break;
686 }
687 leading_zeros += 1;
688 if leading_zeros >= 32 {
689 return Err(BitstreamFilterError::MalformedSequenceHeader);
690 }
691 }
692 if leading_zeros == 0 {
693 return Ok(0);
694 }
695 let value = r.read_bits(leading_zeros as u8)?;
696 Ok((1 << leading_zeros) + value - 1)
697}
698
699pub fn find_av1_sequence_header(data: &[u8]) -> BitstreamResult<Option<Av1SequenceHeader>> {
703 let obus = split_av1_obus(data)?;
704 for obu in obus {
705 if obu.obu_type == Av1ObuType::SequenceHeader {
706 return parse_av1_sequence_header(&obu.payload).map(Some);
707 }
708 }
709 Ok(None)
710}
711
/// Strips emulation-prevention bytes from an escaped NAL payload.
///
/// Every `0x03` that immediately follows two `0x00` bytes is dropped; all
/// other bytes are copied unchanged. The zero count restarts after each
/// removed byte, so `00 00 03 00 00 03` becomes `00 00 00 00`.
pub fn remove_emulation_prevention(data: &[u8]) -> Vec<u8> {
    let mut unescaped = Vec::with_capacity(data.len());
    let mut zero_run = 0usize;
    for &byte in data {
        if zero_run >= 2 && byte == 0x03 {
            zero_run = 0;
            continue;
        }
        zero_run = if byte == 0x00 { zero_run + 1 } else { 0 };
        unescaped.push(byte);
    }
    unescaped
}
735
#[cfg(test)]
mod tests {
    use super::*;

    /// Four-byte Annex B start code used when assembling fixtures.
    const SC4: [u8; 4] = [0x00, 0x00, 0x00, 0x01];

    /// Builds an Annex B stream in which every NAL is preceded by [`SC4`].
    fn annexb_stream(nals: &[&[u8]]) -> Vec<u8> {
        let mut stream = Vec::new();
        for nal in nals {
            stream.extend_from_slice(&SC4);
            stream.extend_from_slice(nal);
        }
        stream
    }

    #[test]
    fn test_split_annexb_single_nal_4byte_startcode() {
        let stream = annexb_stream(&[&[0x67, 0xAB, 0xCD]]);
        let units = split_annexb(&stream);
        assert_eq!(units.len(), 1);
        assert_eq!(units[0], &[0x67, 0xAB, 0xCD]);
    }

    #[test]
    fn test_split_annexb_multiple_nals() {
        // A four-byte start code followed by a three-byte one.
        let stream = [
            0x00, 0x00, 0x00, 0x01, 0x67, 0x11, //
            0x00, 0x00, 0x01, 0x68, 0x22,
        ];
        let units = split_annexb(&stream);
        assert_eq!(units.len(), 2);
        assert_eq!(units[0], &[0x67, 0x11]);
        assert_eq!(units[1], &[0x68, 0x22]);
    }

    #[test]
    fn test_annexb_to_avcc_roundtrip() {
        let sps: &[u8] = &[0x67, 0x42, 0x00, 0x1E];
        let pps: &[u8] = &[0x68, 0xCE, 0x38, 0x80];
        let annexb = annexb_stream(&[sps, pps]);

        let avcc = annexb_to_avcc(&annexb, LengthPrefixSize::Four).unwrap();
        let restored = avcc_to_annexb(&avcc, LengthPrefixSize::Four).unwrap();

        // Re-splitting the converted stream must reproduce both units verbatim.
        assert_eq!(split_annexb(&restored), vec![sps, pps]);
    }

    #[test]
    fn test_avcc_to_annexb_two_byte_prefix() {
        let nal = [0x65u8, 0x11, 0x22];
        let mut avcc = 3u16.to_be_bytes().to_vec();
        avcc.extend_from_slice(&nal);

        let annexb = avcc_to_annexb(&avcc, LengthPrefixSize::Two).unwrap();
        let (start_code, body) = annexb.split_at(SC4.len());
        assert_eq!(start_code, &SC4);
        assert_eq!(body, &nal);
    }

    #[test]
    fn test_avcc_invalid_length_prefix_error() {
        // The prefix announces 100 bytes but only two follow.
        let mut avcc = 100u32.to_be_bytes().to_vec();
        avcc.extend_from_slice(&[0xAA, 0xBB]);
        let err = avcc_to_annexb(&avcc, LengthPrefixSize::Four).unwrap_err();
        assert!(matches!(
            err,
            BitstreamFilterError::InvalidLengthPrefix { .. }
        ));
    }

    #[test]
    fn test_extract_sps_pps() {
        let stream = annexb_stream(&[
            &[0x67, 0x42, 0x00, 0x1E], // SPS
            &[0x68, 0xCE],             // PPS
            &[0x65, 0x88],             // IDR slice
        ]);
        let (sps, pps) = extract_sps_pps(&stream);
        assert_eq!(sps.len(), 1);
        assert_eq!(pps.len(), 1);
        assert_eq!(sps[0][0], 0x67);
        assert_eq!(pps[0][0], 0x68);
    }

    #[test]
    fn test_remove_emulation_prevention() {
        let escaped = [0x00u8, 0x00, 0x03, 0x01, 0xFF];
        assert_eq!(remove_emulation_prevention(&escaped), [0x00, 0x00, 0x01, 0xFF]);
    }

    #[test]
    fn test_split_av1_obus_sequence_header() {
        let payload = [0x00u8; 4];
        // Header 0x0A: obu_type 1 (sequence header) with obu_has_size_field; size 4.
        let mut data = vec![0x0A, 0x04];
        data.extend_from_slice(&payload);

        let obus = split_av1_obus(&data).unwrap();
        assert_eq!(obus.len(), 1);
        assert_eq!(obus[0].obu_type, Av1ObuType::SequenceHeader);
        assert_eq!(obus[0].payload, payload);
    }

    #[test]
    fn test_split_av1_obus_empty_error() {
        assert_eq!(
            split_av1_obus(&[]).unwrap_err(),
            BitstreamFilterError::EmptyPacket
        );
    }

    #[test]
    fn test_split_av1_obus_multiple() {
        let data = [
            0x12, 0x00, // temporal delimiter (type 2), size 0
            0x22, 0x02, 0xAA, 0xBB, // tile group (type 4), size 2
        ];
        let obus = split_av1_obus(&data).unwrap();
        let kinds: Vec<Av1ObuType> = obus.iter().map(|obu| obu.obu_type).collect();
        assert_eq!(kinds, [Av1ObuType::TemporalDelimiter, Av1ObuType::TileGroup]);
        assert_eq!(obus[1].payload, [0xAA, 0xBB]);
    }

    #[test]
    fn test_leb128_multi_byte() {
        // 300 = 0b10_0101100 -> 0xAC 0x02
        assert_eq!(read_leb128(&[0xAC, 0x02], 0).unwrap(), (300, 2));
    }

    #[test]
    fn test_empty_packet_error() {
        for result in [
            annexb_to_avcc(&[], LengthPrefixSize::Four),
            avcc_to_annexb(&[], LengthPrefixSize::Four),
        ] {
            assert_eq!(result.unwrap_err(), BitstreamFilterError::EmptyPacket);
        }
    }

    #[test]
    fn test_length_prefix_size_from_raw() {
        let expected = [
            (1, LengthPrefixSize::One),
            (2, LengthPrefixSize::Two),
            (4, LengthPrefixSize::Four),
        ];
        for (raw, size) in expected {
            assert_eq!(LengthPrefixSize::from_raw(raw).unwrap(), size);
        }
        assert!(LengthPrefixSize::from_raw(3).is_err());
    }
}