1use crate::{BitstreamReader, VideoError};
5
/// A displacement attached to an inter-predicted macroblock.
///
/// The unit of `dx`/`dy` depends on the consumer: `motion_compensate_16x16` adds them to
/// whole-pixel coordinates, while `motion_compensate_halfpel` adds them to quarter-pixel
/// coordinates. The vectors built in this file (`predict_mv`, `decode_p_macroblock`) always
/// carry `ref_idx == 0`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct MotionVector {
    /// Horizontal offset; positive values sample further to the right in the reference.
    pub dx: i16,
    /// Vertical offset; positive values sample further down in the reference.
    pub dy: i16,
    /// Index of the reference frame the vector points into.
    pub ref_idx: usize,
}
17
18pub fn parse_mvd(reader: &mut BitstreamReader) -> Result<(i16, i16), VideoError> {
20 let mvd_x = reader.read_se()?;
21 let mvd_y = reader.read_se()?;
22 Ok((mvd_x as i16, mvd_y as i16))
23}
24
25pub fn predict_mv(left: MotionVector, top: MotionVector, top_right: MotionVector) -> MotionVector {
27 MotionVector {
28 dx: median_of_three(left.dx, top.dx, top_right.dx),
29 dy: median_of_three(left.dy, top.dy, top_right.dy),
30 ref_idx: 0,
31 }
32}
33
/// Returns the middle value of `a`, `b` and `c` (duplicates count individually).
fn median_of_three(a: i16, b: i16, c: i16) -> i16 {
    // The smaller of (a, b) bounds the median from below; the larger of (a, b), capped by c,
    // bounds it from above. The median is the larger of those two bounds.
    let lower = a.min(b);
    let upper = a.max(b).min(c);
    lower.max(upper)
}
39
40#[allow(clippy::too_many_arguments)]
48pub fn motion_compensate_16x16(
49 reference: &[u8],
50 ref_width: usize,
51 ref_height: usize,
52 channels: usize,
53 mv: MotionVector,
54 mb_x: usize,
55 mb_y: usize,
56 output: &mut [u8],
57 out_width: usize,
58) {
59 let src_x = (mb_x * 16) as i32 + mv.dx as i32;
60 let src_y = (mb_y * 16) as i32 + mv.dy as i32;
61
62 for row in 0..16 {
63 for col in 0..16 {
64 let sy = (src_y + row as i32).clamp(0, ref_height as i32 - 1) as usize;
65 let sx = (src_x + col as i32).clamp(0, ref_width as i32 - 1) as usize;
66 let dst_y = mb_y * 16 + row;
67 let dst_x = mb_x * 16 + col;
68 for c in 0..channels {
69 let dst_idx = (dst_y * out_width + dst_x) * channels + c;
70 let src_idx = (sy * ref_width + sx) * channels + c;
71 if dst_idx < output.len() && src_idx < reference.len() {
72 output[dst_idx] = reference[src_idx];
73 }
74 }
75 }
76 }
77}
78
79#[allow(clippy::too_many_arguments)]
85pub fn motion_compensate_halfpel(
86 reference: &[u8],
87 ref_width: usize,
88 ref_height: usize,
89 channels: usize,
90 mv: MotionVector,
91 mb_x: usize,
92 mb_y: usize,
93 output: &mut [u8],
94 out_width: usize,
95) {
96 let base_x = (mb_x * 16) as i32 * 4 + mv.dx as i32;
97 let base_y = (mb_y * 16) as i32 * 4 + mv.dy as i32;
98
99 for row in 0..16 {
100 for col in 0..16 {
101 let qx = base_x + col as i32 * 4;
102 let qy = base_y + row as i32 * 4;
103
104 let ix = qx >> 2;
106 let iy = qy >> 2;
107 let fx = (qx & 3) as u16;
108 let fy = (qy & 3) as u16;
109
110 let hx = fx.div_ceil(2); let hy = fy.div_ceil(2);
113
114 let x0 = ix.clamp(0, ref_width as i32 - 1) as usize;
115 let y0 = iy.clamp(0, ref_height as i32 - 1) as usize;
116 let x1 = (ix + 1).clamp(0, ref_width as i32 - 1) as usize;
117 let y1 = (iy + 1).clamp(0, ref_height as i32 - 1) as usize;
118
119 let dst_y = mb_y * 16 + row;
120 let dst_x = mb_x * 16 + col;
121
122 for c in 0..channels {
123 let s00 = reference[(y0 * ref_width + x0) * channels + c] as u16;
124 let s10 = reference[(y0 * ref_width + x1) * channels + c] as u16;
125 let s01 = reference[(y1 * ref_width + x0) * channels + c] as u16;
126 let s11 = reference[(y1 * ref_width + x1) * channels + c] as u16;
127
128 let val = if hx == 0 && hy == 0 {
130 s00
131 } else if hx > 0 && hy == 0 {
132 (s00 + s10).div_ceil(2)
133 } else if hx == 0 && hy > 0 {
134 (s00 + s01).div_ceil(2)
135 } else {
136 (s00 + s10 + s01 + s11 + 2) / 4
137 };
138
139 let dst_idx = (dst_y * out_width + dst_x) * channels + c;
140 if dst_idx < output.len() {
141 output[dst_idx] = val as u8;
142 }
143 }
144 }
145 }
146}
147
148#[allow(clippy::too_many_arguments)]
158pub fn decode_p_macroblock(
159 reader: &mut BitstreamReader,
160 reference_frame: &[u8],
161 ref_width: usize,
162 ref_height: usize,
163 mb_x: usize,
164 mb_y: usize,
165 neighbor_mvs: &[MotionVector],
166 output: &mut [u8],
167 out_width: usize,
168) -> Result<MotionVector, VideoError> {
169 let _mb_type = reader.read_ue()?;
171
172 let (mvd_x, mvd_y) = parse_mvd(reader)?;
174
175 let predicted = predict_mv(
177 neighbor_mvs.first().copied().unwrap_or_default(),
178 neighbor_mvs.get(1).copied().unwrap_or_default(),
179 neighbor_mvs.get(2).copied().unwrap_or_default(),
180 );
181
182 let mv = MotionVector {
184 dx: predicted.dx + mvd_x,
185 dy: predicted.dy + mvd_y,
186 ref_idx: 0,
187 };
188
189 motion_compensate_16x16(
191 reference_frame,
192 ref_width,
193 ref_height,
194 3,
195 mv,
196 mb_x,
197 mb_y,
198 output,
199 out_width,
200 );
201
202 Ok(mv)
203}
204
/// Bounded FIFO of reference frames.
///
/// Holds at most `max_refs` frames. Pushing into a full buffer evicts the oldest frame.
/// Index 0 is the oldest retained frame; [`latest`](Self::latest) is the newest.
#[derive(Debug, Clone)]
pub struct ReferenceFrameBuffer {
    // A deque makes evicting the oldest frame O(1) instead of shifting every remaining frame.
    frames: std::collections::VecDeque<Vec<u8>>,
    max_refs: usize,
}

impl ReferenceFrameBuffer {
    /// Creates an empty buffer that retains at most `max_refs` frames.
    ///
    /// With `max_refs == 0` the buffer never stores anything: [`push`](Self::push) discards
    /// the frame.
    pub fn new(max_refs: usize) -> Self {
        Self {
            frames: std::collections::VecDeque::new(),
            max_refs,
        }
    }

    /// Appends `frame` as the newest reference, evicting the oldest frame(s) when full.
    pub fn push(&mut self, frame: Vec<u8>) {
        if self.max_refs == 0 {
            // Zero capacity: nothing can be retained, and there is nothing to evict.
            return;
        }
        while self.frames.len() >= self.max_refs {
            self.frames.pop_front();
        }
        self.frames.push_back(frame);
    }

    /// Returns the frame at `idx` (0 = oldest retained), or `None` if out of range.
    pub fn get(&self, idx: usize) -> Option<&[u8]> {
        self.frames.get(idx).map(Vec::as_slice)
    }

    /// Returns the most recently pushed frame that is still retained, if any.
    pub fn latest(&self) -> Option<&[u8]> {
        self.frames.back().map(Vec::as_slice)
    }

    /// Number of frames currently retained.
    pub fn len(&self) -> usize {
        self.frames.len()
    }

    /// Returns `true` when no frames are retained.
    pub fn is_empty(&self) -> bool {
        self.frames.is_empty()
    }
}
256
#[cfg(test)]
mod tests {
    // Unit tests for MV prediction, motion compensation, the reference buffer and MVD parsing.
    use super::*;

    /// Shorthand for a motion vector into reference 0.
    fn mv(dx: i16, dy: i16) -> MotionVector {
        MotionVector { dx, dy, ref_idx: 0 }
    }

    /// Appends the unsigned Exp-Golomb code of `value` to `bits` (one bit per element, MSB first).
    fn push_exp_golomb(bits: &mut Vec<u8>, value: u32) {
        // Prefix of (width - 1) zero bits, then `value + 1` written in `width` bits.
        let code = value + 1;
        let width = u32::BITS - code.leading_zeros();
        bits.resize(bits.len() + (width - 1) as usize, 0);
        bits.extend((0..width).rev().map(|shift| ((code >> shift) & 1) as u8));
    }

    /// Appends the signed Exp-Golomb code of `value`: positives map to odd codes, the rest to even.
    fn push_signed_exp_golomb(bits: &mut Vec<u8>, value: i32) {
        let code = if value > 0 {
            (2 * value - 1) as u32
        } else {
            (-2 * value) as u32
        };
        push_exp_golomb(bits, code);
    }

    /// Packs one-bit-per-element data MSB-first into bytes; a short final chunk is zero-filled.
    fn bits_to_bytes(bits: &[u8]) -> Vec<u8> {
        bits.chunks(8)
            .map(|chunk| {
                chunk
                    .iter()
                    .enumerate()
                    .fold(0u8, |acc, (i, &bit)| acc | (bit << (7 - i)))
            })
            .collect()
    }

    /// Encodes an `(mvd_x, mvd_y)` pair as signed Exp-Golomb, padded to a whole byte.
    fn encode_mvd(x: i32, y: i32) -> Vec<u8> {
        let mut bits = Vec::new();
        push_signed_exp_golomb(&mut bits, x);
        push_signed_exp_golomb(&mut bits, y);
        bits.resize(bits.len().next_multiple_of(8), 0);
        bits_to_bytes(&bits)
    }

    #[test]
    fn motion_vector_median_prediction() {
        let pred = predict_mv(mv(2, -4), mv(6, 1), mv(-3, 8));
        assert_eq!((pred.dx, pred.dy), (2, 1));

        let zero = MotionVector::default();
        let pred_zero = predict_mv(zero, zero, zero);
        assert_eq!((pred_zero.dx, pred_zero.dy), (0, 0));

        let pred2 = predict_mv(mv(5, 5), mv(5, 5), mv(-10, 20));
        assert_eq!((pred2.dx, pred2.dy), (5, 5));
    }

    #[test]
    fn motion_compensate_copies_block() {
        const W: usize = 32;
        const H: usize = 32;
        // Each pixel stores its own row index, so vertical displacement is directly visible.
        let reference: Vec<u8> = (0..W * H).map(|i| (i / W) as u8).collect();

        let run = |vector: MotionVector| -> Vec<u8> {
            let mut out = vec![0u8; W * H];
            motion_compensate_16x16(&reference, W, H, 1, vector, 0, 0, &mut out, W);
            out
        };

        let still = run(mv(0, 0));
        for (row, col) in (0..16).flat_map(|r| (0..16).map(move |c| (r, c))) {
            assert_eq!(still[row * W + col], row as u8, "mismatch at ({row}, {col})");
        }

        let shifted = run(mv(4, 2));
        for (row, col) in (0..16).flat_map(|r| (0..16).map(move |c| (r, c))) {
            let expected_src_y = (row as i32 + 2).clamp(0, H as i32 - 1) as u8;
            assert_eq!(
                shifted[row * W + col],
                expected_src_y,
                "offset mismatch at ({row}, {col})"
            );
        }
    }

    #[test]
    fn reference_frame_buffer_fifo() {
        let mut buf = ReferenceFrameBuffer::new(3);
        assert!(buf.is_empty());
        assert_eq!(buf.len(), 0);
        assert!(buf.latest().is_none());

        for frame in [[1u8, 2, 3], [4, 5, 6], [7, 8, 9]] {
            buf.push(frame.to_vec());
        }
        assert_eq!(buf.len(), 3);
        let expected: [&[u8]; 3] = [&[1, 2, 3], &[4, 5, 6], &[7, 8, 9]];
        for (idx, want) in expected.iter().enumerate() {
            assert_eq!(buf.get(idx), Some(*want));
        }
        assert_eq!(buf.latest(), Some([7u8, 8, 9].as_slice()));

        // A fourth push evicts the oldest frame.
        buf.push(vec![10, 11, 12]);
        assert_eq!(buf.len(), 3);
        assert_eq!(buf.get(0), Some([4u8, 5, 6].as_slice()));
        assert_eq!(buf.latest(), Some([10u8, 11, 12].as_slice()));
        assert!(buf.get(3).is_none());
    }

    #[test]
    fn parse_mvd_roundtrip() {
        for (x, y) in [(3, -5), (0, 0)] {
            let bytes = encode_mvd(x, y);
            let mut reader = BitstreamReader::new(&bytes);
            let (mvd_x, mvd_y) = parse_mvd(&mut reader).unwrap();
            assert_eq!(mvd_x, x as i16);
            assert_eq!(mvd_y, y as i16);
        }
    }
}