Skip to main content

yscv_video/
codec.rs

1use super::error::VideoError;
2use super::frame::Rgb8Frame;
3
/// Identifies the codec a [`VideoDecoder`] or [`VideoEncoder`] works with.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum VideoCodec {
    /// H.264 video.
    H264,
    /// H.265 video.
    H265,
    /// Raw video (presumably uncompressed frames; confirm against implementors).
    Raw,
}
11
12/// Trait for video decoders that convert compressed NAL units to RGB8 frames.
13pub trait VideoDecoder: Send {
14    fn codec(&self) -> VideoCodec;
15
16    /// Decode a single compressed access unit into an RGB8 frame.
17    /// Returns `None` if the decoder needs more data (e.g., initial SPS/PPS).
18    fn decode(
19        &mut self,
20        data: &[u8],
21        timestamp_us: u64,
22    ) -> Result<Option<DecodedFrame>, VideoError>;
23
24    /// Flush any remaining buffered frames.
25    fn flush(&mut self) -> Result<Vec<DecodedFrame>, VideoError>;
26}
27
28/// Trait for video encoders that compress RGB8 frames.
29pub trait VideoEncoder: Send {
30    fn codec(&self) -> VideoCodec;
31
32    /// Encode one RGB8 frame, returning zero or more compressed packets.
33    fn encode(&mut self, frame: &Rgb8Frame) -> Result<Vec<EncodedPacket>, VideoError>;
34
35    /// Flush remaining buffered packets.
36    fn flush(&mut self) -> Result<Vec<EncodedPacket>, VideoError>;
37}
38
39/// A decoded video frame with metadata.
40#[derive(Debug, Clone)]
41pub struct DecodedFrame {
42    pub width: usize,
43    pub height: usize,
44    pub rgb8_data: Vec<u8>,
45    pub timestamp_us: u64,
46    pub keyframe: bool,
47}
48
49impl DecodedFrame {
50    pub fn into_rgb8_frame(self, frame_index: u64) -> Result<Rgb8Frame, VideoError> {
51        Rgb8Frame::from_bytes(
52            frame_index,
53            self.timestamp_us,
54            self.width,
55            self.height,
56            bytes::Bytes::from(self.rgb8_data),
57        )
58    }
59}
60
/// One packet of compressed video produced by a [`VideoEncoder`].
#[derive(Debug, Clone)]
pub struct EncodedPacket {
    /// Compressed payload bytes.
    pub data: Vec<u8>,
    /// Timestamp of the packet (presumably microseconds, going by the name).
    pub timestamp_us: u64,
    /// Whether the packet is flagged as a keyframe.
    pub keyframe: bool,
}
68
// ── H.264 Annex B NAL unit parser ──────────────────────────────────
70
/// H.264 NAL unit types, as encoded in the low five bits of the NAL header byte.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NalUnitType {
    /// Type 1: coded slice of a non-IDR picture.
    Slice,
    /// Type 2: coded slice data partition A.
    SliceA,
    /// Type 3: coded slice data partition B.
    SliceB,
    /// Type 4: coded slice data partition C.
    SliceC,
    /// Type 5: coded slice of an IDR picture.
    Idr,
    /// Type 6: supplemental enhancement information.
    Sei,
    /// Type 7: sequence parameter set.
    Sps,
    /// Type 8: picture parameter set.
    Pps,
    /// Type 9: access unit delimiter.
    Aud,
    /// Any other type value; carries the raw 5-bit code.
    Other(u8),
}

impl NalUnitType {
    /// Classifies a NAL header byte.
    ///
    /// Only the low five bits of `b` are inspected, so the full header byte
    /// can be passed as-is.
    pub fn from_byte(b: u8) -> Self {
        let code = b & 0x1F;
        match code {
            1 => Self::Slice,
            2 => Self::SliceA,
            3 => Self::SliceB,
            4 => Self::SliceC,
            5 => Self::Idr,
            6 => Self::Sei,
            7 => Self::Sps,
            8 => Self::Pps,
            9 => Self::Aud,
            _ => Self::Other(code),
        }
    }

    /// Returns `true` for the slice-carrying (VCL) types: `Slice`, `SliceA`,
    /// `SliceB`, `SliceC` and `Idr`.
    pub fn is_vcl(&self) -> bool {
        // Spelled out exhaustively so a new variant forces a decision here.
        match self {
            Self::Slice | Self::SliceA | Self::SliceB | Self::SliceC | Self::Idr => true,
            Self::Sei | Self::Sps | Self::Pps | Self::Aud | Self::Other(_) => false,
        }
    }
}
113
114/// A parsed NAL unit from an Annex B bitstream.
115#[derive(Debug, Clone)]
116pub struct NalUnit {
117    pub nal_type: NalUnitType,
118    pub nal_ref_idc: u8,
119    pub data: Vec<u8>,
120}
121
122/// Parses H.264 Annex B bitstream into NAL units.
123/// Splits on `0x000001` or `0x00000001` start codes.
124pub fn parse_annex_b(data: &[u8]) -> Vec<NalUnit> {
125    let mut units = Vec::new();
126    let mut i = 0;
127    let len = data.len();
128
129    while i < len {
130        if i + 3 <= len && data[i] == 0 && data[i + 1] == 0 {
131            let (start_code_len, found) = if i + 4 <= len && data[i + 2] == 0 && data[i + 3] == 1 {
132                (4, true)
133            } else if data[i + 2] == 1 {
134                (3, true)
135            } else {
136                (0, false)
137            };
138
139            if found {
140                let nal_start = i + start_code_len;
141                let mut nal_end = nal_start;
142                let mut j = nal_start;
143                while j < len {
144                    if j + 3 <= len
145                        && data[j] == 0
146                        && data[j + 1] == 0
147                        && ((j + 4 <= len && data[j + 2] == 0 && data[j + 3] == 1)
148                            || data[j + 2] == 1)
149                    {
150                        nal_end = j;
151                        break;
152                    }
153                    j += 1;
154                }
155                if j >= len {
156                    nal_end = len;
157                }
158
159                if nal_start < nal_end {
160                    let header = data[nal_start];
161                    let nal_ref_idc = (header >> 5) & 0x03;
162                    let nal_type = NalUnitType::from_byte(header);
163                    units.push(NalUnit {
164                        nal_type,
165                        nal_ref_idc,
166                        data: data[nal_start..nal_end].to_vec(),
167                    });
168                }
169                i = nal_end;
170                continue;
171            }
172        }
173        i += 1;
174    }
175
176    units
177}
178
179/// Extracts SPS and PPS NAL units from an Annex B bitstream.
180pub fn extract_parameter_sets(nals: &[NalUnit]) -> (Option<&NalUnit>, Option<&NalUnit>) {
181    let sps = nals.iter().find(|n| n.nal_type == NalUnitType::Sps);
182    let pps = nals.iter().find(|n| n.nal_type == NalUnitType::Pps);
183    (sps, pps)
184}
185
// ── Simple MP4 box parser ──────────────────────────────────────────
187
/// Header information of one MP4 box; the payload itself is not stored.
#[derive(Debug, Clone)]
pub struct Mp4Box {
    /// Four-byte type code (e.g. `b"moov"`).
    pub box_type: [u8; 4],
    /// Total size of the box in bytes, header included.
    pub size: u64,
    /// Size of the header in bytes (8, or 16 when a 64-bit size is present).
    pub header_size: u8,
    /// Byte offset at which the box starts in the buffer it was parsed from.
    pub offset: u64,
}

impl Mp4Box {
    /// Returns the type code as text, or `"????"` if it is not valid UTF-8.
    pub fn type_str(&self) -> &str {
        match std::str::from_utf8(&self.box_type) {
            Ok(code) => code,
            Err(_) => "????",
        }
    }
}
202
203/// Parses top-level MP4 boxes from a byte buffer.
204pub fn parse_mp4_boxes(data: &[u8]) -> Result<Vec<Mp4Box>, VideoError> {
205    let mut boxes = Vec::new();
206    let mut offset = 0u64;
207    let len = data.len() as u64;
208
209    while offset + 8 <= len {
210        let o = offset as usize;
211        let size_32 = u32::from_be_bytes([data[o], data[o + 1], data[o + 2], data[o + 3]]);
212        let box_type = [data[o + 4], data[o + 5], data[o + 6], data[o + 7]];
213
214        let (size, header_size) = if size_32 == 1 {
215            if offset + 16 > len {
216                return Err(VideoError::ContainerParse(
217                    "truncated extended box size".into(),
218                ));
219            }
220            let extended = u64::from_be_bytes([
221                data[o + 8],
222                data[o + 9],
223                data[o + 10],
224                data[o + 11],
225                data[o + 12],
226                data[o + 13],
227                data[o + 14],
228                data[o + 15],
229            ]);
230            (extended, 16u8)
231        } else if size_32 == 0 {
232            (len - offset, 8u8)
233        } else {
234            (size_32 as u64, 8u8)
235        };
236
237        if size < header_size as u64 {
238            return Err(VideoError::ContainerParse(format!(
239                "box size {} smaller than header at offset {offset}",
240                size
241            )));
242        }
243
244        boxes.push(Mp4Box {
245            box_type,
246            size,
247            header_size,
248            offset,
249        });
250
251        offset += size;
252    }
253
254    Ok(boxes)
255}
256
257/// Finds a box by 4-char type code.
258pub fn find_box<'a>(boxes: &'a [Mp4Box], box_type: &[u8; 4]) -> Option<&'a Mp4Box> {
259    boxes.iter().find(|b| &b.box_type == box_type)
260}
261
262/// Parses child boxes inside a parent box.
263pub fn parse_child_boxes(data: &[u8], parent: &Mp4Box) -> Result<Vec<Mp4Box>, VideoError> {
264    let start = (parent.offset + parent.header_size as u64) as usize;
265    let end = (parent.offset + parent.size) as usize;
266    if end > data.len() || start >= end {
267        return Ok(Vec::new());
268    }
269    let child_data = &data[start..end];
270    let mut children = parse_mp4_boxes(child_data)?;
271    for child in &mut children {
272        child.offset += start as u64;
273    }
274    Ok(children)
275}