Skip to main content

yscv_video/
h264_motion.rs

1//! H.264 P-slice motion compensation: motion vector parsing, prediction,
2//! reference frame buffering, and inter-frame block copy.
3
4use crate::{BitstreamReader, VideoError};
5
6// ---------------------------------------------------------------------------
7// Motion vector
8// ---------------------------------------------------------------------------
9
/// Displacement of a macroblock partition relative to its own position,
/// together with the reference picture it points into.
///
/// The unit of `dx`/`dy` is chosen by the consumer: `motion_compensate_16x16`
/// reads them as whole pixels, `motion_compensate_halfpel` as quarter pixels.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct MotionVector {
    /// Horizontal displacement (positive moves the source block to the right).
    pub dx: i16,
    /// Vertical displacement (positive moves the source block downwards).
    pub dy: i16,
    /// Reference picture index. Every producer in this module currently
    /// writes `0` here.
    pub ref_idx: usize,
}
17
18/// Parse motion vector difference from bitstream (Exp-Golomb coded).
19pub fn parse_mvd(reader: &mut BitstreamReader) -> Result<(i16, i16), VideoError> {
20    let mvd_x = reader.read_se()?;
21    let mvd_y = reader.read_se()?;
22    Ok((mvd_x as i16, mvd_y as i16))
23}
24
25/// Predict motion vector from neighboring blocks (median prediction).
26pub fn predict_mv(left: MotionVector, top: MotionVector, top_right: MotionVector) -> MotionVector {
27    MotionVector {
28        dx: median_of_three(left.dx, top.dx, top_right.dx),
29        dy: median_of_three(left.dy, top.dy, top_right.dy),
30        ref_idx: 0,
31    }
32}
33
/// Returns the middle value of `a`, `b` and `c`.
fn median_of_three(a: i16, b: i16, c: i16) -> i16 {
    // Order the first two, then pull the third into their range: whatever
    // ends up between the smaller and the larger of `a`/`b` is the median.
    let (low, high) = if a <= b { (a, b) } else { (b, a) };
    c.clamp(low, high)
}
39
40// ---------------------------------------------------------------------------
41// Motion compensation
42// ---------------------------------------------------------------------------
43
44/// Apply motion compensation: copy a 16x16 block from the reference frame with
45/// the given motion vector offset. Pixel coordinates that fall outside the
46/// reference frame are clamped to the nearest edge sample.
47#[allow(clippy::too_many_arguments)]
48pub fn motion_compensate_16x16(
49    reference: &[u8],
50    ref_width: usize,
51    ref_height: usize,
52    channels: usize,
53    mv: MotionVector,
54    mb_x: usize,
55    mb_y: usize,
56    output: &mut [u8],
57    out_width: usize,
58) {
59    let src_x = (mb_x * 16) as i32 + mv.dx as i32;
60    let src_y = (mb_y * 16) as i32 + mv.dy as i32;
61
62    for row in 0..16 {
63        for col in 0..16 {
64            let sy = (src_y + row as i32).clamp(0, ref_height as i32 - 1) as usize;
65            let sx = (src_x + col as i32).clamp(0, ref_width as i32 - 1) as usize;
66            let dst_y = mb_y * 16 + row;
67            let dst_x = mb_x * 16 + col;
68            for c in 0..channels {
69                let dst_idx = (dst_y * out_width + dst_x) * channels + c;
70                let src_idx = (sy * ref_width + sx) * channels + c;
71                if dst_idx < output.len() && src_idx < reference.len() {
72                    output[dst_idx] = reference[src_idx];
73                }
74            }
75        }
76    }
77}
78
79/// Apply half-pel interpolation for sub-pixel motion vectors.
80///
81/// Motion vectors are in quarter-pel units. Integer-pel positions are copied
82/// directly; fractional positions use bilinear interpolation between the two
83/// nearest integer samples (half-pel approximation).
84#[allow(clippy::too_many_arguments)]
85pub fn motion_compensate_halfpel(
86    reference: &[u8],
87    ref_width: usize,
88    ref_height: usize,
89    channels: usize,
90    mv: MotionVector,
91    mb_x: usize,
92    mb_y: usize,
93    output: &mut [u8],
94    out_width: usize,
95) {
96    let base_x = (mb_x * 16) as i32 * 4 + mv.dx as i32;
97    let base_y = (mb_y * 16) as i32 * 4 + mv.dy as i32;
98
99    for row in 0..16 {
100        for col in 0..16 {
101            let qx = base_x + col as i32 * 4;
102            let qy = base_y + row as i32 * 4;
103
104            // Integer sample position and fractional offset (0..3)
105            let ix = qx >> 2;
106            let iy = qy >> 2;
107            let fx = (qx & 3) as u16;
108            let fy = (qy & 3) as u16;
109
110            // Round quarter-pel to half-pel grid (0, 2, or snap to nearest)
111            let hx = fx.div_ceil(2); // 0->0, 1->1, 2->1, 3->2 but we only use 0 or 1
112            let hy = fy.div_ceil(2);
113
114            let x0 = ix.clamp(0, ref_width as i32 - 1) as usize;
115            let y0 = iy.clamp(0, ref_height as i32 - 1) as usize;
116            let x1 = (ix + 1).clamp(0, ref_width as i32 - 1) as usize;
117            let y1 = (iy + 1).clamp(0, ref_height as i32 - 1) as usize;
118
119            let dst_y = mb_y * 16 + row;
120            let dst_x = mb_x * 16 + col;
121
122            for c in 0..channels {
123                let s00 = reference[(y0 * ref_width + x0) * channels + c] as u16;
124                let s10 = reference[(y0 * ref_width + x1) * channels + c] as u16;
125                let s01 = reference[(y1 * ref_width + x0) * channels + c] as u16;
126                let s11 = reference[(y1 * ref_width + x1) * channels + c] as u16;
127
128                // Bilinear blend: weight = hx, hy in {0, 1, 2} mapped to 0..2
129                let val = if hx == 0 && hy == 0 {
130                    s00
131                } else if hx > 0 && hy == 0 {
132                    (s00 + s10).div_ceil(2)
133                } else if hx == 0 && hy > 0 {
134                    (s00 + s01).div_ceil(2)
135                } else {
136                    (s00 + s10 + s01 + s11 + 2) / 4
137                };
138
139                let dst_idx = (dst_y * out_width + dst_x) * channels + c;
140                if dst_idx < output.len() {
141                    output[dst_idx] = val as u8;
142                }
143            }
144        }
145    }
146}
147
148// ---------------------------------------------------------------------------
149// P-slice macroblock decoder
150// ---------------------------------------------------------------------------
151
152/// Decode a P-slice macroblock: parse mb_type, motion vectors, and apply
153/// motion compensation from the reference frame.
154///
155/// Returns the decoded motion vector so the caller can store it for
156/// neighboring-block prediction of subsequent macroblocks.
157#[allow(clippy::too_many_arguments)]
158pub fn decode_p_macroblock(
159    reader: &mut BitstreamReader,
160    reference_frame: &[u8],
161    ref_width: usize,
162    ref_height: usize,
163    mb_x: usize,
164    mb_y: usize,
165    neighbor_mvs: &[MotionVector],
166    output: &mut [u8],
167    out_width: usize,
168) -> Result<MotionVector, VideoError> {
169    // 1. Parse mb_type (P_L0_16x16 = 0 for simplest case)
170    let _mb_type = reader.read_ue()?;
171
172    // 2. Parse motion vector difference
173    let (mvd_x, mvd_y) = parse_mvd(reader)?;
174
175    // 3. Predict MV from neighbors (left, top, top-right)
176    let predicted = predict_mv(
177        neighbor_mvs.first().copied().unwrap_or_default(),
178        neighbor_mvs.get(1).copied().unwrap_or_default(),
179        neighbor_mvs.get(2).copied().unwrap_or_default(),
180    );
181
182    // 4. Final MV = predicted + difference
183    let mv = MotionVector {
184        dx: predicted.dx + mvd_x,
185        dy: predicted.dy + mvd_y,
186        ref_idx: 0,
187    };
188
189    // 5. Motion compensate (integer-pel for P_L0_16x16)
190    motion_compensate_16x16(
191        reference_frame,
192        ref_width,
193        ref_height,
194        3,
195        mv,
196        mb_x,
197        mb_y,
198        output,
199        out_width,
200    );
201
202    Ok(mv)
203}
204
205// ---------------------------------------------------------------------------
206// Reference frame buffer
207// ---------------------------------------------------------------------------
208
/// Simple reference frame buffer for P-slice decoding.
///
/// Maintains a bounded FIFO of recent reconstructed frames so that P-slices
/// can reference them for motion compensation. Frames are stored oldest
/// first; once `max_refs` frames are held, pushing another one evicts the
/// oldest.
pub struct ReferenceFrameBuffer {
    /// Retained frames, ordered from oldest (index 0) to newest.
    frames: Vec<Vec<u8>>,
    /// Upper bound on `frames.len()`.
    max_refs: usize,
}

impl ReferenceFrameBuffer {
    /// Creates a new buffer that keeps at most `max_refs` reference frames.
    ///
    /// A capacity of `0` is allowed and yields a buffer that never retains
    /// anything.
    pub fn new(max_refs: usize) -> Self {
        Self {
            frames: Vec::new(),
            max_refs,
        }
    }

    /// Pushes a reconstructed frame into the buffer, evicting the oldest if
    /// the capacity is exceeded.
    ///
    /// With a capacity of `0` the frame is dropped immediately.
    pub fn push(&mut self, frame: Vec<u8>) {
        // Without this guard a zero-capacity buffer would try to evict from
        // an empty `Vec`, and `Vec::remove` panics on an out-of-range index.
        if self.max_refs == 0 {
            return;
        }
        if self.frames.len() >= self.max_refs {
            self.frames.remove(0);
        }
        self.frames.push(frame);
    }

    /// Returns the reference frame at `idx` (0 = oldest retained frame), or
    /// `None` if fewer than `idx + 1` frames are stored.
    pub fn get(&self, idx: usize) -> Option<&[u8]> {
        self.frames.get(idx).map(Vec::as_slice)
    }

    /// Returns the most recently pushed reference frame, or `None` if the
    /// buffer is empty.
    pub fn latest(&self) -> Option<&[u8]> {
        self.frames.last().map(Vec::as_slice)
    }

    /// Returns the number of reference frames currently stored.
    pub fn len(&self) -> usize {
        self.frames.len()
    }

    /// Returns `true` if the buffer contains no reference frames.
    pub fn is_empty(&self) -> bool {
        self.frames.is_empty()
    }
}
256
257// ---------------------------------------------------------------------------
258// Tests
259// ---------------------------------------------------------------------------
260
#[cfg(test)]
mod tests {
    use super::*;

    // -- Bitstream construction helpers --------------------------------------
    //
    // Bits are collected one per `u8` (each element is 0 or 1) and packed
    // into bytes, most significant bit first, by `bits_to_bytes`.

    /// Appends the unsigned Exp-Golomb code of `value`.
    fn push_exp_golomb(bits: &mut Vec<u8>, value: u32) {
        if value == 0 {
            bits.push(1);
            return;
        }
        let code = value + 1;
        let bit_len = 32 - code.leading_zeros();
        // Prefix of `bit_len - 1` zero bits ...
        bits.extend(std::iter::repeat(0u8).take(bit_len as usize - 1));
        // ... followed by `code` itself, most significant bit first.
        bits.extend((0..bit_len).rev().map(|shift| ((code >> shift) & 1) as u8));
    }

    /// Appends the signed Exp-Golomb code of `value`: positive values map to
    /// odd code numbers, negative values to even ones.
    fn push_signed_exp_golomb(bits: &mut Vec<u8>, value: i32) {
        let code = match value {
            v if v > 0 => (2 * v - 1) as u32,
            v if v < 0 => (2 * (-v)) as u32,
            _ => 0,
        };
        push_exp_golomb(bits, code);
    }

    /// Packs a bit sequence into bytes, MSB first. A trailing partial group
    /// is left-aligned in its byte.
    fn bits_to_bytes(bits: &[u8]) -> Vec<u8> {
        bits.chunks(8)
            .map(|group| {
                group
                    .iter()
                    .enumerate()
                    .fold(0u8, |byte, (i, &bit)| byte | (bit << (7 - i)))
            })
            .collect()
    }

    /// Encodes an `(x, y)` MVD pair as signed Exp-Golomb, zero-padded to a
    /// byte boundary.
    fn encode_mvd(x: i32, y: i32) -> Vec<u8> {
        let mut bits = Vec::new();
        push_signed_exp_golomb(&mut bits, x);
        push_signed_exp_golomb(&mut bits, y);
        let padded_len = bits.len().next_multiple_of(8);
        bits.resize(padded_len, 0);
        bits_to_bytes(&bits)
    }

    /// Shorthand for a motion vector pointing into reference 0.
    fn mv(dx: i16, dy: i16) -> MotionVector {
        MotionVector { dx, dy, ref_idx: 0 }
    }

    // -- 1. Median prediction -----------------------------------------------

    #[test]
    fn motion_vector_median_prediction() {
        // median(2, 6, -3) = 2; median(-4, 1, 8) = 1
        let mixed = predict_mv(mv(2, -4), mv(6, 1), mv(-3, 8));
        assert_eq!(mixed.dx, 2);
        assert_eq!(mixed.dy, 1);

        // All-zero neighbors
        let origin = MotionVector::default();
        let from_zeros = predict_mv(origin, origin, origin);
        assert_eq!(from_zeros.dx, 0);
        assert_eq!(from_zeros.dy, 0);

        // Two equal, one different
        let majority = predict_mv(mv(5, 5), mv(5, 5), mv(-10, 20));
        assert_eq!(majority.dx, 5);
        assert_eq!(majority.dy, 5);
    }

    // -- 2. Motion compensation block copy ----------------------------------

    #[test]
    fn motion_compensate_copies_block() {
        // 32x32 single-channel reference in which every pixel stores its row
        // index.
        let ref_w = 32;
        let ref_h = 32;
        let channels = 1;
        let reference: Vec<u8> = (0..ref_w * ref_h * channels)
            .map(|i| (i / ref_w) as u8)
            .collect();

        // MB at (0, 0) with mv=(0, 0) should copy top-left 16x16
        let mut output = vec![0u8; ref_w * ref_h * channels];
        motion_compensate_16x16(
            &reference,
            ref_w,
            ref_h,
            channels,
            mv(0, 0),
            0,
            0,
            &mut output,
            ref_w,
        );
        for row in 0..16 {
            for col in 0..16 {
                assert_eq!(
                    output[row * ref_w + col],
                    row as u8,
                    "mismatch at ({row}, {col})"
                );
            }
        }

        // MB at (0, 0) with mv=(4, 2) should copy from (4, 2)
        let mut output2 = vec![0u8; ref_w * ref_h * channels];
        motion_compensate_16x16(
            &reference,
            ref_w,
            ref_h,
            channels,
            mv(4, 2),
            0,
            0,
            &mut output2,
            ref_w,
        );
        for row in 0..16 {
            let expected_src_y = (row as i32 + 2).clamp(0, ref_h as i32 - 1) as u8;
            for col in 0..16 {
                assert_eq!(
                    output2[row * ref_w + col],
                    expected_src_y,
                    "offset mismatch at ({row}, {col})"
                );
            }
        }
    }

    // -- 3. Reference frame buffer FIFO -------------------------------------

    #[test]
    fn reference_frame_buffer_fifo() {
        let mut dpb = ReferenceFrameBuffer::new(3);
        assert!(dpb.is_empty());
        assert_eq!(dpb.len(), 0);
        assert!(dpb.latest().is_none());

        for frame in [vec![1, 2, 3], vec![4, 5, 6], vec![7, 8, 9]] {
            dpb.push(frame);
        }
        assert_eq!(dpb.len(), 3);
        assert_eq!(dpb.get(0), Some(&[1u8, 2, 3][..]));
        assert_eq!(dpb.get(1), Some(&[4u8, 5, 6][..]));
        assert_eq!(dpb.get(2), Some(&[7u8, 8, 9][..]));
        assert_eq!(dpb.latest(), Some(&[7u8, 8, 9][..]));

        // Pushing a 4th frame evicts the oldest
        dpb.push(vec![10, 11, 12]);
        assert_eq!(dpb.len(), 3);
        assert_eq!(dpb.get(0), Some(&[4u8, 5, 6][..]));
        assert_eq!(dpb.latest(), Some(&[10u8, 11, 12][..]));
        assert!(dpb.get(3).is_none());
    }

    // -- 4. parse_mvd roundtrip ---------------------------------------------

    #[test]
    fn parse_mvd_roundtrip() {
        // Encode se(3) and se(-5) into a bitstream, then parse them back.
        let encoded = encode_mvd(3, -5);
        let mut reader = BitstreamReader::new(&encoded);
        let (mvd_x, mvd_y) = parse_mvd(&mut reader).unwrap();
        assert_eq!(mvd_x, 3);
        assert_eq!(mvd_y, -5);

        // Zero MVD
        let encoded_zero = encode_mvd(0, 0);
        let mut reader_zero = BitstreamReader::new(&encoded_zero);
        let (zero_x, zero_y) = parse_mvd(&mut reader_zero).unwrap();
        assert_eq!(zero_x, 0);
        assert_eq!(zero_y, 0);
    }
}