Skip to main content

oximedia_codec/
bitstream_filter.rs

1//! Bitstream filters for codec-level NAL and OBU transformations.
2//!
//! Provides conversions between AnnexB and AVCC/HVCC packet formats,
//! H.264 SPS/PPS extraction, H.264/H.265 emulation-prevention removal, and
//! AV1 OBU splitting and sequence header parsing.
5
6use std::fmt;
7
/// Failure modes reported by the bitstream filter functions in this module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BitstreamFilterError {
    /// The input ended before the operation had as many bytes as it required.
    BufferTooShort {
        /// How many bytes the operation required.
        needed: usize,
        /// How many bytes the buffer actually held.
        available: usize,
    },
    /// A length field announced more bytes than the buffer still holds.
    InvalidLengthPrefix {
        /// Position of the length field within the buffer.
        offset: usize,
        /// Byte count announced by the length field.
        claimed: usize,
        /// Bytes actually left after the length field.
        available: usize,
    },
    /// An AV1 OBU header, its extension byte, or its LEB128 size could not be decoded.
    MalformedObuHeader {
        /// Position of the offending data within the buffer.
        offset: usize,
    },
    /// The AV1 sequence header payload ended early or was otherwise invalid.
    MalformedSequenceHeader,
    /// The packet handed to a conversion contained no bytes.
    EmptyPacket,
    /// A NAL unit type that the requested operation cannot handle.
    UnknownNalType(u8),
    /// A length-prefix width other than 1, 2, or 4 bytes.
    InvalidLengthPrefixSize(u8),
}

impl fmt::Display for BitstreamFilterError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            Self::BufferTooShort { needed, available } => write!(
                f,
                "buffer too short: needed {needed}, available {available}"
            ),
            Self::InvalidLengthPrefix {
                offset,
                claimed,
                available,
            } => write!(
                f,
                "invalid length prefix at offset {offset}: claims {claimed} bytes but only {available} remain"
            ),
            Self::MalformedObuHeader { offset } => {
                write!(f, "malformed OBU header at offset {offset}")
            }
            Self::MalformedSequenceHeader => f.write_str("malformed AV1 sequence header"),
            Self::EmptyPacket => f.write_str("packet is empty"),
            Self::UnknownNalType(nal_type) => write!(f, "unknown NAL unit type: {nal_type}"),
            Self::InvalidLengthPrefixSize(size) => write!(
                f,
                "invalid length prefix size: {size} (must be 1, 2, or 4)"
            ),
        }
    }
}

impl std::error::Error for BitstreamFilterError {}

/// Shorthand `Result` whose error type is [`BitstreamFilterError`].
pub type BitstreamResult<T> = Result<T, BitstreamFilterError>;
78
79// ---------------------------------------------------------------------------
80// AnnexB ↔ AVCC conversion
81// ---------------------------------------------------------------------------
82
83/// Number of bytes used to encode each NAL length in AVCC/HVCC format.
84#[derive(Debug, Clone, Copy, PartialEq, Eq)]
85pub enum LengthPrefixSize {
86    /// 1-byte length prefix (max NAL size: 255).
87    One = 1,
88    /// 2-byte big-endian length prefix (max NAL size: 65535).
89    Two = 2,
90    /// 4-byte big-endian length prefix (max NAL size: 4 GiB − 1).
91    Four = 4,
92}
93
94impl LengthPrefixSize {
95    /// Construct from the raw byte count used in container metadata.
96    pub fn from_raw(raw: u8) -> BitstreamResult<Self> {
97        match raw {
98            1 => Ok(Self::One),
99            2 => Ok(Self::Two),
100            4 => Ok(Self::Four),
101            other => Err(BitstreamFilterError::InvalidLengthPrefixSize(other)),
102        }
103    }
104
105    /// Returns the numeric value.
106    pub fn as_usize(self) -> usize {
107        self as usize
108    }
109}
110
/// The AnnexB start-code prefix (3-byte short form).
const START_CODE_3: [u8; 3] = [0x00, 0x00, 0x01];
/// The AnnexB start-code prefix (4-byte long form).
const START_CODE_4: [u8; 4] = [0x00, 0x00, 0x00, 0x01];

/// Split an AnnexB byte stream into individual raw NAL unit byte slices.
///
/// Both 3-byte (`00 00 01`) and 4-byte (`00 00 00 01`) start codes are
/// recognised. The returned slices do **not** include the start code.
///
/// Zero bytes at the end of each NAL unit are dropped. H.264/H.265 forbid a
/// NAL unit from ending in `0x00`, so such bytes are Annex B stream padding
/// (`leading_zero_8bits` / `trailing_zero_8bits`), not payload. Without this
/// trimming, extra zero padding before the first start code would surface as
/// a spurious one-byte `[0x00]` NAL unit. NAL units made only of zero bytes
/// are skipped entirely.
pub fn split_annexb(data: &[u8]) -> Vec<&[u8]> {
    /// Push `nal` with trailing zero bytes removed, unless nothing remains.
    fn push_trimmed<'a>(nals: &mut Vec<&'a [u8]>, nal: &'a [u8]) {
        let end = nal
            .iter()
            .rposition(|&b| b != 0x00)
            .map_or(0, |last_non_zero| last_non_zero + 1);
        if end > 0 {
            nals.push(&nal[..end]);
        }
    }

    let mut nals: Vec<&[u8]> = Vec::new();
    let len = data.len();

    // Skip a leading start code if present.
    let mut start = if data.starts_with(&START_CODE_4) {
        START_CODE_4.len()
    } else if data.starts_with(&START_CODE_3) {
        START_CODE_3.len()
    } else {
        0
    };

    let mut i = start;
    while i + 2 < len {
        if data[i] == 0x00 && data[i + 1] == 0x00 {
            // Length of the start code beginning at `i`, or 0 if there is none.
            let code_len = if i + 3 < len && data[i + 2] == 0x00 && data[i + 3] == 0x01 {
                START_CODE_4.len()
            } else if data[i + 2] == 0x01 {
                START_CODE_3.len()
            } else {
                0
            };
            if code_len != 0 {
                push_trimmed(&mut nals, &data[start..i]);
                i += code_len;
                start = i;
                continue;
            }
        }
        i += 1;
    }

    // Tail NAL unit: everything after the last start code.
    push_trimmed(&mut nals, &data[start..]);
    nals
}
166
167/// Convert an AnnexB packet to AVCC/HVCC length-prefixed format.
168///
169/// Each NAL unit delimited by start codes is prefixed with a big-endian
170/// integer whose width is determined by `prefix_size`.
171pub fn annexb_to_avcc(data: &[u8], prefix_size: LengthPrefixSize) -> BitstreamResult<Vec<u8>> {
172    if data.is_empty() {
173        return Err(BitstreamFilterError::EmptyPacket);
174    }
175    let nals = split_annexb(data);
176    let prefix_bytes = prefix_size.as_usize();
177    let total: usize = nals.iter().map(|n| prefix_bytes + n.len()).sum();
178    let mut out = Vec::with_capacity(total);
179
180    for nal in nals {
181        let nal_len = nal.len();
182        match prefix_size {
183            LengthPrefixSize::One => {
184                out.push(nal_len as u8);
185            }
186            LengthPrefixSize::Two => {
187                out.extend_from_slice(&(nal_len as u16).to_be_bytes());
188            }
189            LengthPrefixSize::Four => {
190                out.extend_from_slice(&(nal_len as u32).to_be_bytes());
191            }
192        }
193        out.extend_from_slice(nal);
194    }
195
196    Ok(out)
197}
198
199/// Convert an AVCC/HVCC length-prefixed packet to AnnexB start-code format.
200///
201/// `prefix_size` must match the value in the container's parameter set record
202/// (`AVCDecoderConfigurationRecord.lengthSizeMinusOne + 1`).
203pub fn avcc_to_annexb(data: &[u8], prefix_size: LengthPrefixSize) -> BitstreamResult<Vec<u8>> {
204    if data.is_empty() {
205        return Err(BitstreamFilterError::EmptyPacket);
206    }
207    let prefix_bytes = prefix_size.as_usize();
208    let mut out = Vec::with_capacity(data.len() + data.len() / 4);
209    let mut offset = 0usize;
210
211    while offset < data.len() {
212        if offset + prefix_bytes > data.len() {
213            return Err(BitstreamFilterError::BufferTooShort {
214                needed: offset + prefix_bytes,
215                available: data.len(),
216            });
217        }
218        let nal_len = read_be_uint(&data[offset..offset + prefix_bytes], prefix_bytes);
219        offset += prefix_bytes;
220
221        let remaining = data.len() - offset;
222        if nal_len > remaining {
223            return Err(BitstreamFilterError::InvalidLengthPrefix {
224                offset: offset - prefix_bytes,
225                claimed: nal_len,
226                available: remaining,
227            });
228        }
229        out.extend_from_slice(&START_CODE_4);
230        out.extend_from_slice(&data[offset..offset + nal_len]);
231        offset += nal_len;
232    }
233
234    Ok(out)
235}
236
/// Decode a big-endian unsigned integer from the first `n` bytes of `bytes`.
///
/// Only widths 1, 2, and 4 are supported; any other `n` yields 0.
/// Panics if `bytes` holds fewer than `n` bytes.
fn read_be_uint(bytes: &[u8], n: usize) -> usize {
    if !matches!(n, 1 | 2 | 4) {
        return 0;
    }
    bytes[..n]
        .iter()
        .fold(0usize, |acc, &byte| (acc << 8) | usize::from(byte))
}
246
247// ---------------------------------------------------------------------------
248// H.264 / H.265 NAL unit type extraction
249// ---------------------------------------------------------------------------
250
/// Classification of an H.264 NAL unit by its 5-bit `nal_unit_type` field.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum H264NalType {
    /// Coded slice of a non-IDR picture (`nal_unit_type` 1).
    NonIdrSlice,
    /// Coded slice of an IDR picture (`nal_unit_type` 5).
    IdrSlice,
    /// Supplemental enhancement information (`nal_unit_type` 6).
    Sei,
    /// Sequence parameter set (`nal_unit_type` 7).
    Sps,
    /// Picture parameter set (`nal_unit_type` 8).
    Pps,
    /// Access unit delimiter (`nal_unit_type` 9).
    Aud,
    /// End of sequence (`nal_unit_type` 10).
    EndOfSeq,
    /// End of stream (`nal_unit_type` 11).
    EndOfStream,
    /// Filler data (`nal_unit_type` 12).
    FillerData,
    /// Any type without a dedicated variant; carries the raw 5-bit value.
    Other(u8),
}

impl H264NalType {
    /// Classify a NAL unit from its first (header) byte.
    ///
    /// Only the low five bits (`nal_unit_type`) are inspected; the
    /// `forbidden_zero_bit` and `nal_ref_idc` bits are ignored.
    pub fn from_nal_byte(byte: u8) -> Self {
        const KNOWN: [(u8, H264NalType); 9] = [
            (1, H264NalType::NonIdrSlice),
            (5, H264NalType::IdrSlice),
            (6, H264NalType::Sei),
            (7, H264NalType::Sps),
            (8, H264NalType::Pps),
            (9, H264NalType::Aud),
            (10, H264NalType::EndOfSeq),
            (11, H264NalType::EndOfStream),
            (12, H264NalType::FillerData),
        ];
        let nal_unit_type = byte & 0b0001_1111;
        KNOWN
            .iter()
            .find(|(code, _)| *code == nal_unit_type)
            .map_or(Self::Other(nal_unit_type), |&(_, kind)| kind)
    }
}
293
294/// A parsed, classified NAL unit from an H.264 or H.265 bitstream.
295#[derive(Debug, Clone)]
296pub struct NalUnit<'a> {
297    /// Parsed type (H.264 interpretation).
298    pub nal_type: H264NalType,
299    /// Raw bytes of the NAL unit (excluding any start code).
300    pub data: &'a [u8],
301}
302
303impl<'a> NalUnit<'a> {
304    /// Construct from a raw NAL slice (must be non-empty).
305    pub fn from_raw(data: &'a [u8]) -> Option<Self> {
306        let first = *data.first()?;
307        Some(Self {
308            nal_type: H264NalType::from_nal_byte(first),
309            data,
310        })
311    }
312}
313
314/// Extract all SPS NAL units from an AnnexB bitstream.
315pub fn extract_sps(data: &[u8]) -> Vec<NalUnit<'_>> {
316    split_annexb(data)
317        .into_iter()
318        .filter_map(NalUnit::from_raw)
319        .filter(|n| n.nal_type == H264NalType::Sps)
320        .collect()
321}
322
323/// Extract all PPS NAL units from an AnnexB bitstream.
324pub fn extract_pps(data: &[u8]) -> Vec<NalUnit<'_>> {
325    split_annexb(data)
326        .into_iter()
327        .filter_map(NalUnit::from_raw)
328        .filter(|n| n.nal_type == H264NalType::Pps)
329        .collect()
330}
331
332/// Extract both SPS and PPS NAL units from an AnnexB bitstream as owned bytes.
333pub fn extract_sps_pps(data: &[u8]) -> (Vec<Vec<u8>>, Vec<Vec<u8>>) {
334    let nals = split_annexb(data);
335    let mut sps_list = Vec::new();
336    let mut pps_list = Vec::new();
337    for nal_bytes in nals {
338        if let Some(nal) = NalUnit::from_raw(nal_bytes) {
339            match nal.nal_type {
340                H264NalType::Sps => sps_list.push(nal.data.to_vec()),
341                H264NalType::Pps => pps_list.push(nal.data.to_vec()),
342                _ => {}
343            }
344        }
345    }
346    (sps_list, pps_list)
347}
348
349// ---------------------------------------------------------------------------
350// AV1 OBU parsing
351// ---------------------------------------------------------------------------
352
/// AV1 Open Bitstream Unit (OBU) types, per the AV1 specification §5.3.2.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Av1ObuType {
    /// `OBU_SEQUENCE_HEADER` (1).
    SequenceHeader,
    /// `OBU_TEMPORAL_DELIMITER` (2).
    TemporalDelimiter,
    /// `OBU_FRAME_HEADER` (3).
    FrameHeader,
    /// `OBU_TILE_GROUP` (4).
    TileGroup,
    /// `OBU_METADATA` (5).
    Metadata,
    /// `OBU_FRAME` (6).
    Frame,
    /// `OBU_REDUNDANT_FRAME_HEADER` (7).
    RedundantFrameHeader,
    /// `OBU_TILE_LIST` (8).
    TileList,
    /// `OBU_PADDING` (15).
    Padding,
    /// Any other (reserved) type; carries the raw 4-bit value.
    Reserved(u8),
}

impl Av1ObuType {
    /// Map the raw 4-bit `obu_type` field onto the enum.
    fn from_raw(raw: u8) -> Self {
        // Types 1..=8 are contiguous, so index a table with `raw - 1`.
        const CONTIGUOUS: [Av1ObuType; 8] = [
            Av1ObuType::SequenceHeader,
            Av1ObuType::TemporalDelimiter,
            Av1ObuType::FrameHeader,
            Av1ObuType::TileGroup,
            Av1ObuType::Metadata,
            Av1ObuType::Frame,
            Av1ObuType::RedundantFrameHeader,
            Av1ObuType::TileList,
        ];
        match raw {
            1..=8 => CONTIGUOUS[usize::from(raw - 1)],
            15 => Self::Padding,
            reserved => Self::Reserved(reserved),
        }
    }
}
394
395/// A parsed AV1 OBU unit with its type and payload bytes.
396#[derive(Debug, Clone)]
397pub struct Av1Obu {
398    /// The OBU type.
399    pub obu_type: Av1ObuType,
400    /// The payload bytes (excluding the OBU header and size field).
401    pub payload: Vec<u8>,
402}
403
404/// Parse a low-level unsigned LEB128 value from a byte slice.
405///
406/// Returns `(value, bytes_consumed)` or an error if the encoding is malformed.
407fn read_leb128(data: &[u8], offset: usize) -> BitstreamResult<(u64, usize)> {
408    let mut result: u64 = 0;
409    let mut shift = 0u32;
410    let mut consumed = 0usize;
411    loop {
412        if offset + consumed >= data.len() {
413            return Err(BitstreamFilterError::MalformedObuHeader { offset });
414        }
415        let byte = data[offset + consumed];
416        consumed += 1;
417        result |= ((byte & 0x7F) as u64) << shift;
418        if byte & 0x80 == 0 {
419            break;
420        }
421        shift += 7;
422        if shift >= 56 {
423            return Err(BitstreamFilterError::MalformedObuHeader { offset });
424        }
425    }
426    Ok((result, consumed))
427}
428
429/// Split a contiguous AV1 bitstream (e.g., from an ISOBMFF `av01` sample) into
430/// individual OBU units.
431///
432/// Assumes the standard OBU framing with `obu_has_size_field = 1`.
433pub fn split_av1_obus(data: &[u8]) -> BitstreamResult<Vec<Av1Obu>> {
434    if data.is_empty() {
435        return Err(BitstreamFilterError::EmptyPacket);
436    }
437    let mut obus = Vec::new();
438    let mut offset = 0usize;
439    let len = data.len();
440
441    while offset < len {
442        if offset >= len {
443            break;
444        }
445        let header_byte = data[offset];
446        let forbidden_bit = (header_byte >> 7) & 1;
447        if forbidden_bit != 0 {
448            return Err(BitstreamFilterError::MalformedObuHeader { offset });
449        }
450        let obu_type_raw = (header_byte >> 3) & 0x0F;
451        let obu_extension_flag = (header_byte >> 2) & 1;
452        let obu_has_size_field = (header_byte >> 1) & 1;
453        offset += 1;
454
455        // Skip extension byte if present.
456        if obu_extension_flag == 1 {
457            if offset >= len {
458                return Err(BitstreamFilterError::MalformedObuHeader { offset });
459            }
460            offset += 1;
461        }
462
463        let payload_len = if obu_has_size_field == 1 {
464            let (sz, consumed) = read_leb128(data, offset)?;
465            offset += consumed;
466            sz as usize
467        } else {
468            // Without a size field the OBU extends to end of data.
469            len - offset
470        };
471
472        if offset + payload_len > len {
473            return Err(BitstreamFilterError::InvalidLengthPrefix {
474                offset,
475                claimed: payload_len,
476                available: len - offset,
477            });
478        }
479
480        let payload = data[offset..offset + payload_len].to_vec();
481        offset += payload_len;
482
483        obus.push(Av1Obu {
484            obu_type: Av1ObuType::from_raw(obu_type_raw),
485            payload,
486        });
487    }
488
489    Ok(obus)
490}
491
492// ---------------------------------------------------------------------------
493// AV1 Sequence Header quick-parse
494// ---------------------------------------------------------------------------
495
/// High-level fields extracted from an AV1 sequence header OBU.
///
/// Only the fields most useful for container-level metadata are decoded;
/// a full reference decoder is out of scope for this filter layer.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Av1SequenceHeader {
    /// `seq_profile` (3 bits): 0 = Main, 1 = High, 2 = Professional.
    pub seq_profile: u8,
    /// `still_picture` flag.
    pub still_picture: bool,
    /// `reduced_still_picture_header` flag.
    pub reduced_still_picture_header: bool,
    /// Maximum frame width in pixels (`max_frame_width_minus_1 + 1`).
    pub max_frame_width: u32,
    /// Maximum frame height in pixels (`max_frame_height_minus_1 + 1`).
    pub max_frame_height: u32,
    /// `high_bitdepth` flag from `color_config()`. When set, samples use more
    /// than 8 bits: 10 bits, or 12 bits if `twelve_bit` is also set.
    pub high_bitdepth: bool,
    /// `twelve_bit` flag. It is only coded when `seq_profile == 2` and
    /// `high_bitdepth` is set; otherwise it is `false`.
    pub twelve_bit: bool,
    /// `mono_chrome` flag. It is never coded for `seq_profile == 1`, where it
    /// is always `false`.
    pub mono_chrome: bool,
}
519
520/// A simple bit-level reader for big-endian bitstreams.
521struct BitReader<'a> {
522    data: &'a [u8],
523    byte_offset: usize,
524    bit_offset: u8,
525}
526
527impl<'a> BitReader<'a> {
528    fn new(data: &'a [u8]) -> Self {
529        Self {
530            data,
531            byte_offset: 0,
532            bit_offset: 0,
533        }
534    }
535
536    fn read_bit(&mut self) -> BitstreamResult<u8> {
537        if self.byte_offset >= self.data.len() {
538            return Err(BitstreamFilterError::MalformedSequenceHeader);
539        }
540        let byte = self.data[self.byte_offset];
541        let bit = (byte >> (7 - self.bit_offset)) & 1;
542        self.bit_offset += 1;
543        if self.bit_offset == 8 {
544            self.bit_offset = 0;
545            self.byte_offset += 1;
546        }
547        Ok(bit)
548    }
549
550    fn read_bits(&mut self, n: u8) -> BitstreamResult<u32> {
551        let mut val = 0u32;
552        for _ in 0..n {
553            val = (val << 1) | self.read_bit()? as u32;
554        }
555        Ok(val)
556    }
557
558    /// Unsigned LEB128 read (for `uvlc()` is not needed here; we use fixed-width reads).
559    /// Read a "u(n)" syntax element.
560    fn u(&mut self, n: u8) -> BitstreamResult<u32> {
561        self.read_bits(n)
562    }
563
564    fn f(&mut self, n: u8) -> BitstreamResult<u32> {
565        self.read_bits(n)
566    }
567}
568
569/// Parse the fields of an AV1 sequence header OBU payload.
570///
571/// Returns `None` if the payload is too short to contain a valid header.
572pub fn parse_av1_sequence_header(payload: &[u8]) -> BitstreamResult<Av1SequenceHeader> {
573    let mut r = BitReader::new(payload);
574
575    let seq_profile = r.f(3)? as u8;
576    let still_picture = r.f(1)? != 0;
577    let reduced_still_picture_header = r.f(1)? != 0;
578
579    // timing_info_present_flag and decoder model etc. — skip when
580    // reduced_still_picture_header is set.
581    let (timing_info_present, decoder_model_info_present) = if reduced_still_picture_header {
582        (false, false)
583    } else {
584        let tip = r.f(1)? != 0;
585        let dmip = if tip {
586            // timing_info() — simplified skip: num_units_in_display_tick (32),
587            // time_scale (32), equal_picture_interval (1) + conditional.
588            r.u(32)?;
589            r.u(32)?;
590            let epi = r.f(1)?;
591            if epi != 0 {
592                // pts_num_ticks_per_picture_minus_1 (uvlc) — skip a few bits.
593                let _ = read_uvlc(&mut r)?;
594            }
595            r.f(1)? != 0
596        } else {
597            false
598        };
599        if dmip {
600            // decoder_model_info(): buffer_delay_length_minus_1(5), ...
601            // We do a best-effort skip of 40 bits (approximate).
602            let _ = r.u(5)?;
603            let _ = r.u(32)?;
604            let _ = r.u(9)?;
605        }
606        (tip, dmip)
607    };
608
609    let _ = timing_info_present;
610    let _ = decoder_model_info_present;
611
612    // operating_points_cnt_minus_1 — when not reduced header.
613    if !reduced_still_picture_header {
614        let op_cnt = r.u(5)?; // operating_points_cnt_minus_1 (5 bits)
615        for _ in 0..=op_cnt {
616            let _op_idc = r.u(12)?;
617            let _seq_level_idx = r.u(5)?;
618            let seq_tier = if r.u(5)? > 7 { r.u(1)? } else { 0 };
619            let _ = seq_tier;
620            if decoder_model_info_present {
621                let _decoder_model_present = r.u(1)?;
622                // skip operating_parameters_info if present
623            }
624            if !reduced_still_picture_header {
625                let _initial_display_delay_present = r.u(1)?;
626                if decoder_model_info_present {
627                    let _initial_display_delay_minus_1 = r.u(4)?;
628                }
629            }
630        }
631    }
632
633    // frame_width_bits_minus_1 (4 bits), frame_height_bits_minus_1 (4 bits)
634    let fw_bits = r.u(4)? + 1;
635    let fh_bits = r.u(4)? + 1;
636    let max_frame_width = r.u(fw_bits as u8)? + 1;
637    let max_frame_height = r.u(fh_bits as u8)? + 1;
638
639    // frame_id_numbers_present (1 bit) — if not reduced.
640    if !reduced_still_picture_header {
641        let frame_id_numbers_present = r.u(1)?;
642        if frame_id_numbers_present != 0 {
643            let _delta_frame_id_length = r.u(4)?;
644            let _additional_frame_id_length = r.u(3)?;
645        }
646    }
647
648    // use_128x128_superblock (1), enable_filter_intra (1), enable_intra_edge_filter (1)
649    let _use_128 = r.u(1)?;
650    let _enable_filter_intra = r.u(1)?;
651    let _enable_intra_edge_filter = r.u(1)?;
652
653    // Skip remaining non-color fields for brevity.
654    // color_config()
655    let high_bitdepth = r.u(1)? != 0;
656    let twelve_bit = if seq_profile == 2 && high_bitdepth {
657        r.u(1)? != 0
658    } else {
659        false
660    };
661    let mono_chrome = if seq_profile == 1 {
662        false
663    } else {
664        r.u(1)? != 0
665    };
666
667    Ok(Av1SequenceHeader {
668        seq_profile,
669        still_picture,
670        reduced_still_picture_header,
671        max_frame_width,
672        max_frame_height,
673        high_bitdepth,
674        twelve_bit,
675        mono_chrome,
676    })
677}
678
679/// Read an unsigned variable-length code (uvlc) value.
680fn read_uvlc(r: &mut BitReader<'_>) -> BitstreamResult<u32> {
681    let mut leading_zeros = 0u32;
682    loop {
683        let bit = r.read_bit()?;
684        if bit != 0 {
685            break;
686        }
687        leading_zeros += 1;
688        if leading_zeros >= 32 {
689            return Err(BitstreamFilterError::MalformedSequenceHeader);
690        }
691    }
692    if leading_zeros == 0 {
693        return Ok(0);
694    }
695    let value = r.read_bits(leading_zeros as u8)?;
696    Ok((1 << leading_zeros) + value - 1)
697}
698
699/// Find and parse the first AV1 sequence header OBU found in a bitstream.
700///
701/// Returns `None` if no sequence header OBU is present.
702pub fn find_av1_sequence_header(data: &[u8]) -> BitstreamResult<Option<Av1SequenceHeader>> {
703    let obus = split_av1_obus(data)?;
704    for obu in obus {
705        if obu.obu_type == Av1ObuType::SequenceHeader {
706            return parse_av1_sequence_header(&obu.payload).map(Some);
707        }
708    }
709    Ok(None)
710}
711
712// ---------------------------------------------------------------------------
713// High-level helper: remove emulation-prevention bytes (RBSP decoding)
714// ---------------------------------------------------------------------------
715
/// Strip H.264/H.265 emulation-prevention bytes, turning NAL payload bytes into RBSP.
///
/// Any `0x03` byte that directly follows two zero bytes is dropped, whatever
/// byte comes next, including when it is the last byte. The zero count then
/// restarts, so `00 00 03 00 00 03` becomes `00 00 00 00`. All other bytes are
/// copied unchanged.
pub fn remove_emulation_prevention(data: &[u8]) -> Vec<u8> {
    let mut rbsp = Vec::with_capacity(data.len());
    let mut zero_run = 0usize;
    for &byte in data {
        if zero_run >= 2 && byte == 0x03 {
            // emulation_prevention_three_byte: discard and restart the zero count.
            zero_run = 0;
            continue;
        }
        zero_run = if byte == 0x00 { zero_run + 1 } else { 0 };
        rbsp.push(byte);
    }
    rbsp
}
735
736// ---------------------------------------------------------------------------
737// Tests
738// ---------------------------------------------------------------------------
739
#[cfg(test)]
mod tests {
    use super::*;

    // --- AnnexB / AVCC round-trip ---

    #[test]
    fn test_split_annexb_single_nal_4byte_startcode() {
        let data = [0x00, 0x00, 0x00, 0x01, 0x67, 0xAB, 0xCD];
        let nals = split_annexb(&data);
        assert_eq!(nals.len(), 1);
        assert_eq!(nals[0], &[0x67, 0xAB, 0xCD]);
    }

    #[test]
    fn test_split_annexb_multiple_nals() {
        let data = [
            0x00, 0x00, 0x00, 0x01, 0x67, 0x11, // SPS
            0x00, 0x00, 0x01, 0x68, 0x22, // PPS (3-byte start code)
        ];
        let nals = split_annexb(&data);
        assert_eq!(nals.len(), 2);
        assert_eq!(nals[0], &[0x67, 0x11]);
        assert_eq!(nals[1], &[0x68, 0x22]);
    }

    #[test]
    fn test_split_annexb_leading_3byte_startcode() {
        let data = [0x00u8, 0x00, 0x01, 0x67, 0x11];
        assert_eq!(split_annexb(&data), vec![&[0x67u8, 0x11][..]]);
    }

    #[test]
    fn test_annexb_to_avcc_roundtrip() {
        let sps = [0x67u8, 0x42, 0x00, 0x1E];
        let pps = [0x68u8, 0xCE, 0x38, 0x80];
        let mut annexb = Vec::new();
        annexb.extend_from_slice(&[0x00, 0x00, 0x00, 0x01]);
        annexb.extend_from_slice(&sps);
        annexb.extend_from_slice(&[0x00, 0x00, 0x00, 0x01]);
        annexb.extend_from_slice(&pps);

        let avcc = annexb_to_avcc(&annexb, LengthPrefixSize::Four).unwrap();
        let back = avcc_to_annexb(&avcc, LengthPrefixSize::Four).unwrap();

        // The round-trip should contain both NAL units.
        let nals = split_annexb(&back);
        assert_eq!(nals.len(), 2);
        assert_eq!(nals[0], &sps);
        assert_eq!(nals[1], &pps);
    }

    #[test]
    fn test_annexb_to_avcc_one_byte_prefix() {
        let data = [
            0x00u8, 0x00, 0x01, 0x65, 0x88, // IDR slice
            0x00, 0x00, 0x01, 0x41, 0x9A, 0x77, // non-IDR slice
        ];
        let avcc = annexb_to_avcc(&data, LengthPrefixSize::One).unwrap();
        assert_eq!(avcc, vec![0x02u8, 0x65, 0x88, 0x03, 0x41, 0x9A, 0x77]);
    }

    #[test]
    fn test_avcc_to_annexb_two_byte_prefix() {
        // Craft a 2-byte AVCC packet with one 3-byte NAL.
        let nal = [0x65u8, 0x11, 0x22];
        let mut avcc = Vec::new();
        avcc.extend_from_slice(&(3u16).to_be_bytes());
        avcc.extend_from_slice(&nal);
        let result = avcc_to_annexb(&avcc, LengthPrefixSize::Two).unwrap();
        assert_eq!(&result[..4], &[0x00, 0x00, 0x00, 0x01]);
        assert_eq!(&result[4..], &nal);
    }

    #[test]
    fn test_avcc_invalid_length_prefix_error() {
        // Claim 100 bytes but only 2 available.
        let mut avcc = Vec::new();
        avcc.extend_from_slice(&(100u32).to_be_bytes());
        avcc.extend_from_slice(&[0xAA, 0xBB]);
        let err = avcc_to_annexb(&avcc, LengthPrefixSize::Four).unwrap_err();
        assert!(matches!(
            err,
            BitstreamFilterError::InvalidLengthPrefix { .. }
        ));
    }

    #[test]
    fn test_avcc_truncated_length_field_error() {
        // A 2-byte prefix is required but only 1 byte exists.
        let err = avcc_to_annexb(&[0x00], LengthPrefixSize::Two).unwrap_err();
        assert_eq!(
            err,
            BitstreamFilterError::BufferTooShort {
                needed: 2,
                available: 1
            }
        );
    }

    // --- NAL classification / SPS/PPS extraction ---

    #[test]
    fn test_h264_nal_type_classification() {
        assert_eq!(H264NalType::from_nal_byte(0x67), H264NalType::Sps);
        assert_eq!(H264NalType::from_nal_byte(0x68), H264NalType::Pps);
        assert_eq!(H264NalType::from_nal_byte(0x65), H264NalType::IdrSlice);
        assert_eq!(H264NalType::from_nal_byte(0x41), H264NalType::NonIdrSlice);
        assert_eq!(H264NalType::from_nal_byte(0x00), H264NalType::Other(0));
    }

    #[test]
    fn test_extract_sps_pps() {
        let mut stream = Vec::new();
        // SPS (nal_type = 7)
        stream.extend_from_slice(&[0x00, 0x00, 0x00, 0x01, 0x67, 0x42, 0x00, 0x1E]);
        // PPS (nal_type = 8)
        stream.extend_from_slice(&[0x00, 0x00, 0x00, 0x01, 0x68, 0xCE]);
        // IDR slice (nal_type = 5)
        stream.extend_from_slice(&[0x00, 0x00, 0x00, 0x01, 0x65, 0x88]);

        let (sps, pps) = extract_sps_pps(&stream);
        assert_eq!(sps.len(), 1);
        assert_eq!(pps.len(), 1);
        // First byte of extracted SPS should be 0x67.
        assert_eq!(sps[0][0], 0x67);
        assert_eq!(pps[0][0], 0x68);
    }

    #[test]
    fn test_extract_sps_and_pps_borrowed() {
        let stream = [
            0x00u8, 0x00, 0x00, 0x01, 0x67, 0x42, // SPS
            0x00, 0x00, 0x01, 0x68, 0xCE, // PPS
            0x00, 0x00, 0x01, 0x67, 0x4D, // second SPS
        ];
        let sps = extract_sps(&stream);
        assert_eq!(sps.len(), 2);
        assert_eq!(sps[1].data, &[0x67u8, 0x4D]);
        let pps = extract_pps(&stream);
        assert_eq!(pps.len(), 1);
        assert_eq!(pps[0].data, &[0x68u8, 0xCE]);
    }

    // --- Emulation prevention removal ---

    #[test]
    fn test_remove_emulation_prevention() {
        let input = [0x00u8, 0x00, 0x03, 0x01, 0xFF];
        let output = remove_emulation_prevention(&input);
        assert_eq!(output, [0x00, 0x00, 0x01, 0xFF]);
    }

    #[test]
    fn test_remove_emulation_prevention_trailing_byte() {
        assert_eq!(
            remove_emulation_prevention(&[0x00, 0x00, 0x03]),
            vec![0x00u8, 0x00]
        );
    }

    // --- AV1 OBU splitting ---

    #[test]
    fn test_split_av1_obus_sequence_header() {
        // Minimal synthetic AV1 packet: one OBU.
        // OBU header: type=1 (sequence), no extension, has_size_field=1.
        // header byte: forbidden=0, type=1, ext=0, has_size=1, reserved=0
        // = 0b0_0001_0_1_0 = 0x0A
        let payload = [0x00u8; 4]; // dummy payload
        let mut data = Vec::new();
        data.push(0x0A); // OBU header
        data.push(0x04); // LEB128 size = 4
        data.extend_from_slice(&payload);

        let obus = split_av1_obus(&data).unwrap();
        assert_eq!(obus.len(), 1);
        assert_eq!(obus[0].obu_type, Av1ObuType::SequenceHeader);
        assert_eq!(obus[0].payload, payload);
    }

    #[test]
    fn test_split_av1_obus_empty_error() {
        let err = split_av1_obus(&[]).unwrap_err();
        assert_eq!(err, BitstreamFilterError::EmptyPacket);
    }

    #[test]
    fn test_split_av1_obus_multiple() {
        // Two OBUs: temporal delimiter (type=2) + tile group (type=4).
        // TD header: type=2 => 0b0_0010_0_1_0 = 0x12, size=0
        // TG header: type=4 => 0b0_0100_0_1_0 = 0x22, size=2
        let mut data = Vec::new();
        data.push(0x12); // temporal delimiter
        data.push(0x00); // size=0
        data.push(0x22); // tile group
        data.push(0x02); // size=2
        data.push(0xAA);
        data.push(0xBB);

        let obus = split_av1_obus(&data).unwrap();
        assert_eq!(obus.len(), 2);
        assert_eq!(obus[0].obu_type, Av1ObuType::TemporalDelimiter);
        assert_eq!(obus[1].obu_type, Av1ObuType::TileGroup);
        assert_eq!(obus[1].payload, [0xAA, 0xBB]);
    }

    #[test]
    fn test_split_av1_obus_extension_and_no_size_field() {
        // Frame OBU (type 6) with extension byte and size field:
        //   0b0_0110_1_1_0 = 0x36, extension 0x00, size 1, payload 0xCC.
        // Temporal delimiter (type 2) without a size field:
        //   0b0_0010_0_0_0 = 0x10, payload runs to the end of the buffer.
        let data = [0x36u8, 0x00, 0x01, 0xCC, 0x10, 0xAA, 0xBB];
        let obus = split_av1_obus(&data).unwrap();
        assert_eq!(obus.len(), 2);
        assert_eq!(obus[0].obu_type, Av1ObuType::Frame);
        assert_eq!(obus[0].payload, [0xCCu8]);
        assert_eq!(obus[1].obu_type, Av1ObuType::TemporalDelimiter);
        assert_eq!(obus[1].payload, [0xAAu8, 0xBB]);
    }

    #[test]
    fn test_split_av1_obus_forbidden_bit_error() {
        assert_eq!(
            split_av1_obus(&[0x80]).unwrap_err(),
            BitstreamFilterError::MalformedObuHeader { offset: 0 }
        );
    }

    #[test]
    fn test_split_av1_obus_size_overrun_error() {
        // Temporal delimiter claiming 5 payload bytes with only 1 present.
        assert_eq!(
            split_av1_obus(&[0x12, 0x05, 0xAA]).unwrap_err(),
            BitstreamFilterError::InvalidLengthPrefix {
                offset: 2,
                claimed: 5,
                available: 1
            }
        );
    }

    #[test]
    fn test_leb128_multi_byte() {
        // Value 300 = 0xAC 0x02 in LEB128.
        let data = [0xACu8, 0x02];
        let (val, consumed) = read_leb128(&data, 0).unwrap();
        assert_eq!(val, 300);
        assert_eq!(consumed, 2);
    }

    #[test]
    fn test_leb128_truncated_and_overlong() {
        // Continuation bit set but no following byte.
        assert_eq!(
            read_leb128(&[0x80], 0).unwrap_err(),
            BitstreamFilterError::MalformedObuHeader { offset: 0 }
        );
        // More than 8 bytes is rejected.
        assert!(read_leb128(&[0xFF; 9], 0).is_err());
        // Exactly 8 bytes with a terminating last byte is accepted.
        let mut eight = [0x80u8; 8];
        eight[7] = 0x01;
        assert_eq!(read_leb128(&eight, 0).unwrap(), (1u64 << 49, 8));
        // A non-zero starting offset is honoured.
        assert_eq!(read_leb128(&[0xFF, 0x05], 1).unwrap(), (5, 1));
    }

    #[test]
    fn test_empty_packet_error() {
        assert_eq!(
            annexb_to_avcc(&[], LengthPrefixSize::Four).unwrap_err(),
            BitstreamFilterError::EmptyPacket
        );
        assert_eq!(
            avcc_to_annexb(&[], LengthPrefixSize::Four).unwrap_err(),
            BitstreamFilterError::EmptyPacket
        );
    }

    #[test]
    fn test_length_prefix_size_from_raw() {
        assert_eq!(
            LengthPrefixSize::from_raw(1).unwrap(),
            LengthPrefixSize::One
        );
        assert_eq!(
            LengthPrefixSize::from_raw(2).unwrap(),
            LengthPrefixSize::Two
        );
        assert_eq!(
            LengthPrefixSize::from_raw(4).unwrap(),
            LengthPrefixSize::Four
        );
        assert!(LengthPrefixSize::from_raw(3).is_err());
    }

    #[test]
    fn test_error_display_messages() {
        assert_eq!(
            BitstreamFilterError::EmptyPacket.to_string(),
            "packet is empty"
        );
        assert_eq!(
            BitstreamFilterError::InvalidLengthPrefixSize(3).to_string(),
            "invalid length prefix size: 3 (must be 1, 2, or 4)"
        );
    }
}