Skip to main content

yscv_video/
hevc_inter.rs

1//! HEVC inter prediction: reference picture management, motion compensation,
2//! merge mode, AMVP, and inter CU parsing (ITU-T H.265 sections 8.5, 8.5.3).
3//!
4//! This module enables P-slice and B-slice decoding by providing:
5//! - [`HevcDpb`] — Decoded Picture Buffer for reference picture management.
6//! - [`HevcMv`] / [`HevcMvField`] — Quarter-pel motion vector types.
7//! - Luma motion compensation with 8-tap interpolation (Table 8-4).
8//! - Bi-prediction averaging.
9//! - Merge candidate list construction (spatial candidates).
10//! - AMVP candidate list construction.
11//! - MVD parsing from CABAC.
12//! - Inter CU prediction data parsing.
13
14use super::hevc_cabac::CabacDecoder;
15use super::hevc_decoder::{HevcSliceType, HevcSps};
16use super::hevc_syntax::HevcSliceCabacState;
17
18// ---------------------------------------------------------------------------
19// Reference picture
20// ---------------------------------------------------------------------------
21
/// A reconstructed reference picture stored in the DPB.
///
/// Only the luma plane is stored; motion compensation in this module
/// (`hevc_mc_luma`) reads luma samples only.
#[derive(Debug, Clone)]
pub struct HevcReferencePicture {
    /// Picture Order Count. The DPB uses it as the picture's identity for
    /// lookup, removal and eviction.
    pub poc: i32,
    /// Reconstructed 8-bit luma samples, row-major, `width * height` elements
    /// (the row stride equals `width`).
    pub luma: Vec<u8>,
    /// Picture width in luma samples.
    pub width: usize,
    /// Picture height in luma samples.
    pub height: usize,
    /// Whether this is a long-term reference picture.
    pub is_long_term: bool,
}
36
37// ---------------------------------------------------------------------------
38// Decoded Picture Buffer (DPB)
39// ---------------------------------------------------------------------------
40
/// Decoded Picture Buffer — manages reference pictures for inter prediction.
///
/// Pictures are kept in insertion order and looked up by POC. The capacity
/// is fixed at construction and is at least 1.
///
/// NOTE(review): the capacity is presumably derived by the caller from
/// `sps_max_dec_pic_buffering_minus1 + 1` — confirm at the construction site.
#[derive(Debug)]
pub struct HevcDpb {
    /// Stored reference pictures, in insertion order.
    pictures: Vec<HevcReferencePicture>,
    /// Number of pictures the buffer holds before `add` evicts one.
    max_size: usize,
}
49
50impl HevcDpb {
51    /// Create a new DPB that holds at most `max_size` reference pictures.
52    pub fn new(max_size: usize) -> Self {
53        Self {
54            pictures: Vec::new(),
55            max_size: max_size.max(1),
56        }
57    }
58
59    /// Insert a reference picture into the DPB.
60    ///
61    /// If the buffer is at capacity the oldest picture is bumped first.
62    pub fn add(&mut self, pic: HevcReferencePicture) {
63        if self.pictures.len() >= self.max_size {
64            self.bump();
65        }
66        self.pictures.push(pic);
67    }
68
69    /// Look up a reference picture by its POC.
70    pub fn get_by_poc(&self, poc: i32) -> Option<&HevcReferencePicture> {
71        self.pictures.iter().find(|p| p.poc == poc)
72    }
73
74    /// Mark a picture as unused by setting its long-term flag to false and
75    /// removing it from the buffer.
76    pub fn mark_unused(&mut self, poc: i32) {
77        self.pictures.retain(|p| p.poc != poc);
78    }
79
80    /// Remove the oldest picture (lowest POC) to make room.
81    pub fn bump(&mut self) {
82        if self.pictures.is_empty() {
83            return;
84        }
85        // Find the picture with the smallest POC.
86        let min_idx = self
87            .pictures
88            .iter()
89            .enumerate()
90            .min_by_key(|(_, p)| p.poc)
91            .map(|(i, _)| i)
92            .unwrap_or(0);
93        self.pictures.remove(min_idx);
94    }
95
96    /// Remove all pictures from the DPB (IDR flush).
97    pub fn clear(&mut self) {
98        self.pictures.clear();
99    }
100
101    /// Number of pictures currently in the buffer.
102    pub fn len(&self) -> usize {
103        self.pictures.len()
104    }
105
106    /// Returns `true` when the buffer is empty.
107    pub fn is_empty(&self) -> bool {
108        self.pictures.is_empty()
109    }
110
111    /// Maximum capacity.
112    pub fn max_size(&self) -> usize {
113        self.max_size
114    }
115}
116
117// ---------------------------------------------------------------------------
118// Motion vector types
119// ---------------------------------------------------------------------------
120
/// Quarter-pel motion vector.
///
/// Each component is a signed 16-bit displacement in quarter-sample units.
/// `v >> 2` gives the integer part and `v & 3` the fractional phase (see
/// `hevc_mc_luma`).
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct HevcMv {
    /// Horizontal displacement in quarter-pel units.
    pub x: i16,
    /// Vertical displacement in quarter-pel units.
    pub y: i16,
}
129
130impl HevcMv {
131    /// Create a motion vector from integer-pel coordinates (internally scaled
132    /// to quarter-pel).
133    pub fn from_fullpel(x: i16, y: i16) -> Self {
134        Self { x: x * 4, y: y * 4 }
135    }
136
137    /// Add two motion vectors (e.g. predictor + difference).
138    pub fn add(self, other: HevcMv) -> HevcMv {
139        HevcMv {
140            x: self.x.saturating_add(other.x),
141            y: self.y.saturating_add(other.y),
142        }
143    }
144
145    /// Negate both components.
146    pub fn negate(self) -> HevcMv {
147        HevcMv {
148            x: self.x.saturating_neg(),
149            y: self.y.saturating_neg(),
150        }
151    }
152}
153
/// Per-PU motion information for the L0 and L1 reference lists.
///
/// In every array, index 0 refers to list L0 and index 1 to list L1.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct HevcMvField {
    /// Motion vectors for L0 and L1 (quarter-pel units).
    pub mv: [HevcMv; 2],
    /// Reference picture indices for L0 and L1 (-1 = unused).
    ///
    /// The derived `Default` sets these to 0, not -1; use
    /// `HevcMvField::unavailable()` for an explicitly empty field.
    pub ref_idx: [i8; 2],
    /// Whether L0 / L1 prediction is active.
    pub pred_flag: [bool; 2],
}
164
165impl HevcMvField {
166    /// An empty (no-prediction) field.
167    pub fn unavailable() -> Self {
168        Self {
169            mv: [HevcMv::default(); 2],
170            ref_idx: [-1, -1],
171            pred_flag: [false, false],
172        }
173    }
174
175    /// Returns `true` when at least one list is active.
176    pub fn is_available(&self) -> bool {
177        self.pred_flag[0] || self.pred_flag[1]
178    }
179}
180
181// ---------------------------------------------------------------------------
182// Luma interpolation filter (HEVC spec Table 8-4)
183// ---------------------------------------------------------------------------
184
/// 8-tap luma interpolation filter coefficients (H.265 luma filter `fL`).
///
/// Indexed by fractional position (0 = integer, 1 = quarter, 2 = half,
/// 3 = three-quarter). Tap `k` is applied to the sample at offset `k - 3`
/// from the integer position. Every row sums to 64, i.e. each filter pass
/// has a gain of `1 << 6` that callers must shift back out.
const HEVC_LUMA_FILTER: [[i16; 8]; 4] = [
    [0, 0, 0, 64, 0, 0, 0, 0],        // Integer pel
    [-1, 4, -10, 58, 17, -5, 1, 0],   // Quarter pel
    [-1, 4, -11, 40, 40, -11, 4, -1], // Half pel
    [0, 1, -5, 17, 58, -10, 4, -1],   // Three-quarter pel
];
195
196// ---------------------------------------------------------------------------
197// Motion compensation
198// ---------------------------------------------------------------------------
199
200/// Fetch a reference sample with edge clamping.
201#[inline]
202fn ref_sample(pic: &HevcReferencePicture, x: i32, y: i32) -> i16 {
203    let cx = x.clamp(0, pic.width as i32 - 1) as usize;
204    let cy = y.clamp(0, pic.height as i32 - 1) as usize;
205    pic.luma[cy * pic.width + cx] as i16
206}
207
208/// Luma motion compensation with quarter-pel interpolation (8-tap filter).
209///
210/// `x`, `y` are the block position in the current picture (full-pel).
211/// `mv` is in quarter-pel units.
212/// `output` receives `block_w * block_h` samples as `i16` (for bi-pred
213/// averaging before final clipping).
214#[allow(clippy::too_many_arguments)]
215pub fn hevc_mc_luma(
216    ref_pic: &HevcReferencePicture,
217    x: i32,
218    y: i32,
219    mv: HevcMv,
220    block_w: usize,
221    block_h: usize,
222    output: &mut [i16],
223) {
224    debug_assert!(output.len() >= block_w * block_h);
225
226    let frac_x = ((mv.x as i32) & 3) as usize;
227    let frac_y = ((mv.y as i32) & 3) as usize;
228    let int_x = x + (mv.x as i32 >> 2);
229    let int_y = y + (mv.y as i32 >> 2);
230
231    let filter_h = &HEVC_LUMA_FILTER[frac_x];
232    let filter_v = &HEVC_LUMA_FILTER[frac_y];
233
234    if frac_x == 0 && frac_y == 0 {
235        // Integer-pel: direct copy.
236        for row in 0..block_h {
237            for col in 0..block_w {
238                output[row * block_w + col] =
239                    ref_sample(ref_pic, int_x + col as i32, int_y + row as i32);
240            }
241        }
242        return;
243    }
244
245    if frac_y == 0 {
246        // Horizontal-only filtering.
247        for row in 0..block_h {
248            for col in 0..block_w {
249                let sy = int_y + row as i32;
250                let mut sum = 0i32;
251                for k in 0..8 {
252                    let sx = int_x + col as i32 + k as i32 - 3;
253                    sum += ref_sample(ref_pic, sx, sy) as i32 * filter_h[k] as i32;
254                }
255                output[row * block_w + col] = ((sum + 32) >> 6) as i16;
256            }
257        }
258        return;
259    }
260
261    if frac_x == 0 {
262        // Vertical-only filtering.
263        for row in 0..block_h {
264            for col in 0..block_w {
265                let sx = int_x + col as i32;
266                let mut sum = 0i32;
267                for k in 0..8 {
268                    let sy = int_y + row as i32 + k as i32 - 3;
269                    sum += ref_sample(ref_pic, sx, sy) as i32 * filter_v[k] as i32;
270                }
271                output[row * block_w + col] = ((sum + 32) >> 6) as i16;
272            }
273        }
274        return;
275    }
276
277    // Both horizontal and vertical fractional position: two-pass filtering.
278    // First filter horizontally into a temporary buffer (with extra rows for
279    // the vertical filter tap extent), then filter vertically.
280    let ext_h = block_h + 7; // 3 rows above + 4 rows below
281    let mut tmp = vec![0i32; ext_h * block_w];
282
283    // Horizontal pass: produce (block_h + 7) rows of block_w intermediate
284    // samples shifted by 6 bits.
285    for row in 0..ext_h {
286        let sy = int_y + row as i32 - 3;
287        for col in 0..block_w {
288            let mut sum = 0i32;
289            for k in 0..8 {
290                let sx = int_x + col as i32 + k as i32 - 3;
291                sum += ref_sample(ref_pic, sx, sy) as i32 * filter_h[k] as i32;
292            }
293            tmp[row * block_w + col] = sum; // keep <<6 headroom
294        }
295    }
296
297    // Vertical pass over the intermediate buffer.
298    for row in 0..block_h {
299        for col in 0..block_w {
300            let mut sum = 0i64;
301            for k in 0..8 {
302                sum += tmp[(row + k) * block_w + col] as i64 * filter_v[k] as i64;
303            }
304            // Two shifts: +32 from h-pass kept, +2048 for v-pass rounding.
305            output[row * block_w + col] = ((sum + 2048) >> 12) as i16;
306        }
307    }
308}
309
/// Bi-prediction average: combine L0 and L1 predictions and clip to [0, 255].
///
/// Each of the first `size` samples is the round-half-up mean
/// `(a + b + 1) >> 1` of the two predictions, clamped into `u8`.
/// `pred_l0` and `pred_l1` are the unclipped `i16` outputs of
/// [`hevc_mc_luma`]. All three slices must hold at least `size` samples.
///
/// NOTE(review): H.265 default weighted prediction averages the two
/// predictions at 14-bit intermediate precision before one rounding shift.
/// Averaging values already rounded to the 8-bit scale (as here) can be off
/// by one on some samples. Confirm whether bit-exact output is required.
pub fn hevc_bipred_average(pred_l0: &[i16], pred_l1: &[i16], output: &mut [u8], size: usize) {
    debug_assert!(pred_l0.len() >= size);
    debug_assert!(pred_l1.len() >= size);
    debug_assert!(output.len() >= size);

    let pairs = pred_l0[..size].iter().zip(&pred_l1[..size]);
    for (dst, (&a, &b)) in output[..size].iter_mut().zip(pairs) {
        let rounded_mean = (i32::from(a) + i32::from(b) + 1) >> 1;
        *dst = rounded_mean.clamp(0, 255) as u8;
    }
}
324
/// Uni-prediction clip: convert the first `size` intermediate `i16` samples
/// to `u8` by clamping them into [0, 255].
///
/// Both slices must hold at least `size` samples.
pub fn hevc_unipred_clip(pred: &[i16], output: &mut [u8], size: usize) {
    debug_assert!(pred.len() >= size);
    debug_assert!(output.len() >= size);
    for (dst, &sample) in output[..size].iter_mut().zip(&pred[..size]) {
        *dst = i32::from(sample).clamp(0, 255) as u8;
    }
}
333
334// ---------------------------------------------------------------------------
335// Merge mode (ITU-T H.265, 8.5.3.2)
336// ---------------------------------------------------------------------------
337
/// Minimum PU size in luma samples (4x4 in HEVC).
///
/// The motion fields passed to the merge/AMVP builders are indexed at this
/// granularity: luma position `(x, y)` maps to min-PU `(x / 4, y / 4)`.
const MIN_PU_SIZE: usize = 4;
340
341/// Build the merge candidate list from spatial neighbours (simplified).
342///
343/// Full HEVC spec constructs up to 5 spatial candidates (A0, A1, B0, B1, B2),
344/// one temporal candidate, and combined bi-pred candidates.  This
345/// implementation covers the spatial candidates.
346///
347/// `mv_field` is the per-min-PU motion field of the current picture,
348/// dimensioned `pic_width_in_min_pu * pic_height_in_min_pu`.
349pub fn build_merge_candidates(
350    mv_field: &[HevcMvField],
351    pic_width_in_min_pu: usize,
352    x: usize,
353    y: usize,
354    block_w: usize,
355    block_h: usize,
356) -> Vec<HevcMvField> {
357    let mut candidates: Vec<HevcMvField> = Vec::with_capacity(5);
358    let max_candidates = 5usize;
359
360    let pu_x = x / MIN_PU_SIZE;
361    let pu_y = y / MIN_PU_SIZE;
362    let pu_w = block_w / MIN_PU_SIZE;
363    let pu_h = block_h / MIN_PU_SIZE;
364
365    // Helper to fetch a candidate if available.
366    let get = |px: usize, py: usize| -> Option<HevcMvField> {
367        if px < pic_width_in_min_pu {
368            let idx = py * pic_width_in_min_pu + px;
369            if idx < mv_field.len() {
370                let f = mv_field[idx];
371                if f.is_available() {
372                    return Some(f);
373                }
374            }
375        }
376        None
377    };
378
379    // A1: left-bottom  — (x - 1, y + block_h - 1)
380    if pu_x > 0
381        && let Some(c) = get(pu_x - 1, pu_y + pu_h - 1)
382    {
383        candidates.push(c);
384    }
385
386    // B1: top-right   — (x + block_w - 1, y - 1)
387    if candidates.len() < max_candidates
388        && pu_y > 0
389        && let Some(c) = get(pu_x + pu_w - 1, pu_y - 1)
390        && (candidates.is_empty() || candidates[candidates.len() - 1] != c)
391    {
392        candidates.push(c);
393    }
394
395    // B0: top-right+1 — (x + block_w, y - 1)
396    if candidates.len() < max_candidates
397        && pu_y > 0
398        && let Some(c) = get(pu_x + pu_w, pu_y - 1)
399        && candidates.last() != Some(&c)
400    {
401        candidates.push(c);
402    }
403
404    // A0: left-bottom+1 — (x - 1, y + block_h)
405    if candidates.len() < max_candidates
406        && pu_x > 0
407        && let Some(c) = get(pu_x - 1, pu_y + pu_h)
408        && candidates.last() != Some(&c)
409    {
410        candidates.push(c);
411    }
412
413    // B2: top-left    — (x - 1, y - 1)
414    if candidates.len() < max_candidates
415        && pu_x > 0
416        && pu_y > 0
417        && let Some(c) = get(pu_x - 1, pu_y - 1)
418        && candidates.last() != Some(&c)
419    {
420        candidates.push(c);
421    }
422
423    // Pad with zero-MV fields up to max_candidates if fewer were found.
424    while candidates.len() < max_candidates {
425        candidates.push(HevcMvField {
426            mv: [HevcMv::default(); 2],
427            ref_idx: [0, -1],
428            pred_flag: [true, false],
429        });
430    }
431
432    candidates
433}
434
435/// Parse `merge_idx` from the CABAC bitstream (truncated unary, bypass coded).
436///
437/// Returns a value in `0 ..= max_merge_cand - 1`.
438pub fn parse_merge_idx(cabac: &mut CabacDecoder<'_>, max_merge_cand: u32) -> u32 {
439    if max_merge_cand <= 1 {
440        return 0;
441    }
442    // First bin is bypass-coded, remaining bins are bypass-coded truncated
443    // unary.
444    let mut idx = 0u32;
445    if cabac.decode_bypass() {
446        idx += 1;
447        while idx < max_merge_cand - 1 {
448            if cabac.decode_bypass() {
449                idx += 1;
450            } else {
451                break;
452            }
453        }
454    }
455    idx
456}
457
458// ---------------------------------------------------------------------------
459// AMVP (ITU-T H.265, 8.5.3.1)
460// ---------------------------------------------------------------------------
461
462/// Build the AMVP candidate list (2 candidates).
463///
464/// Spatial candidates are derived from left (A0/A1) and above (B0/B1/B2)
465/// neighbours.  When fewer than 2 are found, zero-MVs fill the rest.
466pub fn build_amvp_candidates(
467    mv_field: &[HevcMvField],
468    pic_width_in_min_pu: usize,
469    x: usize,
470    y: usize,
471    ref_idx: i8,
472    list: usize,
473) -> [HevcMv; 2] {
474    let mut cands = [HevcMv::default(); 2];
475    let mut count = 0usize;
476
477    let pu_x = x / MIN_PU_SIZE;
478    let pu_y = y / MIN_PU_SIZE;
479
480    let get = |px: usize, py: usize| -> Option<HevcMv> {
481        if px < pic_width_in_min_pu {
482            let idx = py * pic_width_in_min_pu + px;
483            if idx < mv_field.len() {
484                let f = &mv_field[idx];
485                if f.pred_flag[list] && f.ref_idx[list] == ref_idx {
486                    return Some(f.mv[list]);
487                }
488                // Try the other list with same ref_idx (scaling omitted).
489                let other = 1 - list;
490                if f.pred_flag[other] && f.ref_idx[other] == ref_idx {
491                    return Some(f.mv[other]);
492                }
493            }
494        }
495        None
496    };
497
498    // Left group: A0 then A1.
499    if pu_x > 0 {
500        if let Some(mv) = get(pu_x - 1, pu_y) {
501            cands[count] = mv;
502            count += 1;
503        } else if pu_y > 0
504            && let Some(mv) = get(pu_x - 1, pu_y - 1)
505        {
506            cands[count] = mv;
507            count += 1;
508        }
509    }
510
511    // Above group: B0 then B1 then B2.
512    if count < 2
513        && pu_y > 0
514        && let Some(mv) = get(pu_x, pu_y - 1)
515        && (count == 0 || cands[0] != mv)
516    {
517        cands[count] = mv;
518        count += 1;
519    }
520    if count < 2
521        && pu_x > 0
522        && pu_y > 0
523        && let Some(mv) = get(pu_x - 1, pu_y - 1)
524        && (count == 0 || cands[0] != mv)
525    {
526        cands[count] = mv;
527        // count += 1; (last candidate, value not read again)
528    }
529
530    // Zero-MV fill.
531    // (cands already defaults to zero.)
532
533    cands
534}
535
536// ---------------------------------------------------------------------------
537// MVD parsing (ITU-T H.265, 7.3.8.9)
538// ---------------------------------------------------------------------------
539
540/// Parse a motion vector difference from the CABAC bitstream.
541///
542/// Each component (x, y) is coded as:
543///   `abs_mvd_greater0_flag` (bypass) → if set, `abs_mvd_greater1_flag`
544///   (bypass) → if set, `abs_mvd_minus2` (Exp-Golomb order-1, bypass)
545///   → `mvd_sign_flag` (bypass).
546pub fn parse_mvd(cabac: &mut CabacDecoder<'_>) -> HevcMv {
547    let abs_x_gt0 = cabac.decode_bypass();
548    let abs_y_gt0 = cabac.decode_bypass();
549
550    let abs_x_gt1 = if abs_x_gt0 {
551        cabac.decode_bypass()
552    } else {
553        false
554    };
555    let abs_y_gt1 = if abs_y_gt0 {
556        cabac.decode_bypass()
557    } else {
558        false
559    };
560
561    let mut abs_x: i16 = 0;
562    if abs_x_gt0 {
563        abs_x = 1;
564        if abs_x_gt1 {
565            abs_x += 1 + cabac.decode_eg(1) as i16;
566        }
567    }
568
569    let mut abs_y: i16 = 0;
570    if abs_y_gt0 {
571        abs_y = 1;
572        if abs_y_gt1 {
573            abs_y += 1 + cabac.decode_eg(1) as i16;
574        }
575    }
576
577    let sign_x = if abs_x_gt0 {
578        cabac.decode_bypass()
579    } else {
580        false
581    };
582    let sign_y = if abs_y_gt0 {
583        cabac.decode_bypass()
584    } else {
585        false
586    };
587
588    HevcMv {
589        x: if sign_x { -abs_x } else { abs_x },
590        y: if sign_y { -abs_y } else { abs_y },
591    }
592}
593
594// ---------------------------------------------------------------------------
595// Inter CU prediction parsing
596// ---------------------------------------------------------------------------
597
598/// Parse inter prediction data for one CU from the CABAC bitstream.
599///
600/// Returns the resulting [`HevcMvField`] describing L0/L1 motion.
601///
602/// This is a simplified implementation handling merge mode and single-list
603/// explicit MV coding for P-slices, plus basic B-slice bi-prediction.
604#[allow(clippy::too_many_arguments)]
605pub fn parse_inter_prediction(
606    state: &mut HevcSliceCabacState<'_>,
607    sps: &HevcSps,
608    slice_type: HevcSliceType,
609    mv_field: &[HevcMvField],
610    pic_width_in_min_pu: usize,
611    x: usize,
612    y: usize,
613    cu_size: usize,
614) -> HevcMvField {
615    // merge_flag — decoded as bypass (simplified; spec uses context 0).
616    let merge_flag = state.cabac.decode_bypass();
617
618    if merge_flag {
619        // Merge mode: pick candidate.
620        let max_merge = 5u32;
621        let merge_idx = parse_merge_idx(&mut state.cabac, max_merge);
622        let candidates =
623            build_merge_candidates(mv_field, pic_width_in_min_pu, x, y, cu_size, cu_size);
624        let idx = (merge_idx as usize).min(candidates.len().saturating_sub(1));
625        return candidates[idx];
626    }
627
628    // Explicit MV coding.
629    match slice_type {
630        HevcSliceType::P => {
631            // ref_idx_l0 — bypass coded unary (simplified).
632            let ref_idx_l0 = if sps.num_short_term_ref_pic_sets > 1 {
633                let mut idx = 0i8;
634                while (idx as u8) < sps.num_short_term_ref_pic_sets.saturating_sub(1) {
635                    if state.cabac.decode_bypass() {
636                        idx += 1;
637                    } else {
638                        break;
639                    }
640                }
641                idx
642            } else {
643                0i8
644            };
645
646            let mvd = parse_mvd(&mut state.cabac);
647
648            // AMVP predictor.
649            let amvp = build_amvp_candidates(mv_field, pic_width_in_min_pu, x, y, ref_idx_l0, 0);
650            let mvp_flag = state.cabac.decode_bypass();
651            let predictor = if mvp_flag { amvp[1] } else { amvp[0] };
652
653            HevcMvField {
654                mv: [predictor.add(mvd), HevcMv::default()],
655                ref_idx: [ref_idx_l0, -1],
656                pred_flag: [true, false],
657            }
658        }
659        HevcSliceType::B => {
660            // Simplified B-slice: decode both L0 and L1.
661            let ref_idx_l0 = 0i8;
662            let ref_idx_l1 = 0i8;
663
664            let mvd_l0 = parse_mvd(&mut state.cabac);
665            let mvd_l1 = parse_mvd(&mut state.cabac);
666
667            let amvp_l0 = build_amvp_candidates(mv_field, pic_width_in_min_pu, x, y, ref_idx_l0, 0);
668            let amvp_l1 = build_amvp_candidates(mv_field, pic_width_in_min_pu, x, y, ref_idx_l1, 1);
669
670            let mvp0_flag = state.cabac.decode_bypass();
671            let mvp1_flag = state.cabac.decode_bypass();
672
673            let pred0 = if mvp0_flag { amvp_l0[1] } else { amvp_l0[0] };
674            let pred1 = if mvp1_flag { amvp_l1[1] } else { amvp_l1[0] };
675
676            HevcMvField {
677                mv: [pred0.add(mvd_l0), pred1.add(mvd_l1)],
678                ref_idx: [ref_idx_l0, ref_idx_l1],
679                pred_flag: [true, true],
680            }
681        }
682        HevcSliceType::I => {
683            // Should never reach inter prediction in an I-slice.
684            HevcMvField::unavailable()
685        }
686    }
687}
688
689// ---------------------------------------------------------------------------
690// Tests
691// ---------------------------------------------------------------------------
692
693#[cfg(test)]
694mod tests {
695    use super::*;
696
697    // -- DPB tests ----------------------------------------------------------
698
699    #[test]
700    fn dpb_new_is_empty() {
701        let dpb = HevcDpb::new(5);
702        assert!(dpb.is_empty());
703        assert_eq!(dpb.len(), 0);
704        assert_eq!(dpb.max_size(), 5);
705    }
706
707    #[test]
708    fn dpb_add_and_get_by_poc() {
709        let mut dpb = HevcDpb::new(4);
710        dpb.add(HevcReferencePicture {
711            poc: 0,
712            luma: vec![128; 16],
713            width: 4,
714            height: 4,
715            is_long_term: false,
716        });
717        dpb.add(HevcReferencePicture {
718            poc: 2,
719            luma: vec![200; 16],
720            width: 4,
721            height: 4,
722            is_long_term: false,
723        });
724        assert_eq!(dpb.len(), 2);
725        assert!(dpb.get_by_poc(0).is_some());
726        assert_eq!(dpb.get_by_poc(0).unwrap().luma[0], 128);
727        assert!(dpb.get_by_poc(2).is_some());
728        assert!(dpb.get_by_poc(99).is_none());
729    }
730
731    #[test]
732    fn dpb_bump_removes_oldest() {
733        let mut dpb = HevcDpb::new(3);
734        for poc in 0..3 {
735            dpb.add(HevcReferencePicture {
736                poc,
737                luma: vec![poc as u8; 16],
738                width: 4,
739                height: 4,
740                is_long_term: false,
741            });
742        }
743        assert_eq!(dpb.len(), 3);
744        // Adding a 4th triggers bump of the lowest POC (0).
745        dpb.add(HevcReferencePicture {
746            poc: 5,
747            luma: vec![5; 16],
748            width: 4,
749            height: 4,
750            is_long_term: false,
751        });
752        assert_eq!(dpb.len(), 3);
753        assert!(dpb.get_by_poc(0).is_none());
754        assert!(dpb.get_by_poc(1).is_some());
755        assert!(dpb.get_by_poc(5).is_some());
756    }
757
758    #[test]
759    fn dpb_mark_unused() {
760        let mut dpb = HevcDpb::new(4);
761        dpb.add(HevcReferencePicture {
762            poc: 10,
763            luma: vec![0; 16],
764            width: 4,
765            height: 4,
766            is_long_term: false,
767        });
768        dpb.add(HevcReferencePicture {
769            poc: 20,
770            luma: vec![0; 16],
771            width: 4,
772            height: 4,
773            is_long_term: false,
774        });
775        dpb.mark_unused(10);
776        assert_eq!(dpb.len(), 1);
777        assert!(dpb.get_by_poc(10).is_none());
778        assert!(dpb.get_by_poc(20).is_some());
779    }
780
781    #[test]
782    fn dpb_clear() {
783        let mut dpb = HevcDpb::new(4);
784        for poc in 0..4 {
785            dpb.add(HevcReferencePicture {
786                poc,
787                luma: vec![0; 16],
788                width: 4,
789                height: 4,
790                is_long_term: false,
791            });
792        }
793        assert_eq!(dpb.len(), 4);
794        dpb.clear();
795        assert!(dpb.is_empty());
796        assert_eq!(dpb.len(), 0);
797    }
798
799    // -- MV arithmetic tests ------------------------------------------------
800
801    #[test]
802    fn mv_from_fullpel() {
803        let mv = HevcMv::from_fullpel(3, -5);
804        assert_eq!(mv.x, 12);
805        assert_eq!(mv.y, -20);
806    }
807
808    #[test]
809    fn mv_add() {
810        let a = HevcMv { x: 10, y: -4 };
811        let b = HevcMv { x: -3, y: 7 };
812        let c = a.add(b);
813        assert_eq!(c.x, 7);
814        assert_eq!(c.y, 3);
815    }
816
817    #[test]
818    fn mv_negate() {
819        let mv = HevcMv { x: 5, y: -8 };
820        let neg = mv.negate();
821        assert_eq!(neg.x, -5);
822        assert_eq!(neg.y, 8);
823    }
824
825    #[test]
826    fn mv_default_is_zero() {
827        let mv = HevcMv::default();
828        assert_eq!(mv.x, 0);
829        assert_eq!(mv.y, 0);
830    }
831
832    // -- Luma interpolation tests -------------------------------------------
833
834    fn make_ref_pic(w: usize, h: usize, val: u8) -> HevcReferencePicture {
835        HevcReferencePicture {
836            poc: 0,
837            luma: vec![val; w * h],
838            width: w,
839            height: h,
840            is_long_term: false,
841        }
842    }
843
844    #[test]
845    fn mc_luma_integer_pel() {
846        let pic = make_ref_pic(16, 16, 100);
847        let mut out = vec![0i16; 4 * 4];
848        hevc_mc_luma(&pic, 2, 2, HevcMv { x: 0, y: 0 }, 4, 4, &mut out);
849        for v in &out {
850            assert_eq!(*v, 100);
851        }
852    }
853
854    #[test]
855    fn mc_luma_half_pel_uniform() {
856        // With a uniform reference, half-pel filtering should yield the same
857        // value (filter coefficients sum to 64).
858        let pic = make_ref_pic(32, 32, 80);
859        let mut out = vec![0i16; 8 * 8];
860        // frac_x = 2 (half-pel)
861        hevc_mc_luma(&pic, 4, 4, HevcMv { x: 2, y: 0 }, 8, 8, &mut out);
862        for v in &out {
863            assert_eq!(*v, 80);
864        }
865    }
866
867    #[test]
868    fn mc_luma_quarter_pel_uniform() {
869        let pic = make_ref_pic(32, 32, 60);
870        let mut out = vec![0i16; 4 * 4];
871        hevc_mc_luma(&pic, 4, 4, HevcMv { x: 1, y: 0 }, 4, 4, &mut out);
872        for v in &out {
873            assert_eq!(*v, 60);
874        }
875    }
876
877    #[test]
878    fn mc_luma_three_quarter_pel_uniform() {
879        let pic = make_ref_pic(32, 32, 120);
880        let mut out = vec![0i16; 4 * 4];
881        hevc_mc_luma(&pic, 4, 4, HevcMv { x: 3, y: 0 }, 4, 4, &mut out);
882        for v in &out {
883            assert_eq!(*v, 120);
884        }
885    }
886
887    #[test]
888    fn mc_luma_vertical_half_pel_uniform() {
889        let pic = make_ref_pic(32, 32, 90);
890        let mut out = vec![0i16; 4 * 4];
891        hevc_mc_luma(&pic, 4, 4, HevcMv { x: 0, y: 2 }, 4, 4, &mut out);
892        for v in &out {
893            assert_eq!(*v, 90);
894        }
895    }
896
897    #[test]
898    fn mc_luma_diagonal_half_pel_uniform() {
899        let pic = make_ref_pic(32, 32, 70);
900        let mut out = vec![0i16; 4 * 4];
901        hevc_mc_luma(&pic, 4, 4, HevcMv { x: 2, y: 2 }, 4, 4, &mut out);
902        for v in &out {
903            assert_eq!(*v, 70);
904        }
905    }
906
907    #[test]
908    fn mc_luma_gradient_horizontal() {
909        // Horizontal gradient: column c has value 10*c.
910        let w = 32usize;
911        let h = 16usize;
912        let mut luma = vec![0u8; w * h];
913        for row in 0..h {
914            for col in 0..w {
915                luma[row * w + col] = (col * 8).min(255) as u8;
916            }
917        }
918        let pic = HevcReferencePicture {
919            poc: 0,
920            luma,
921            width: w,
922            height: h,
923            is_long_term: false,
924        };
925        let mut out = vec![0i16; 4 * 4];
926        // Integer-pel fetch.
927        hevc_mc_luma(&pic, 4, 2, HevcMv { x: 0, y: 0 }, 4, 4, &mut out);
928        // First column should be col=4 => 32
929        assert_eq!(out[0], 32);
930        assert_eq!(out[1], 40);
931    }
932
933    // -- Bi-prediction averaging tests --------------------------------------
934
935    #[test]
936    fn bipred_average_uniform() {
937        let l0 = vec![100i16; 16];
938        let l1 = vec![200i16; 16];
939        let mut out = vec![0u8; 16];
940        hevc_bipred_average(&l0, &l1, &mut out, 16);
941        // (100 + 200 + 1) >> 1 = 150
942        for v in &out {
943            assert_eq!(*v, 150);
944        }
945    }
946
947    #[test]
948    fn bipred_average_clamping() {
949        let l0 = vec![255i16; 4];
950        let l1 = vec![255i16; 4];
951        let mut out = vec![0u8; 4];
952        hevc_bipred_average(&l0, &l1, &mut out, 4);
953        for v in &out {
954            assert_eq!(*v, 255);
955        }
956
957        let l0n = vec![-10i16; 4];
958        let l1n = vec![-20i16; 4];
959        let mut outn = vec![255u8; 4];
960        hevc_bipred_average(&l0n, &l1n, &mut outn, 4);
961        for v in &outn {
962            assert_eq!(*v, 0);
963        }
964    }
965
966    #[test]
967    fn unipred_clip_basic() {
968        let pred = vec![128i16, -5, 300, 0];
969        let mut out = vec![0u8; 4];
970        hevc_unipred_clip(&pred, &mut out, 4);
971        assert_eq!(out[0], 128);
972        assert_eq!(out[1], 0);
973        assert_eq!(out[2], 255);
974        assert_eq!(out[3], 0);
975    }
976
977    // -- Merge candidate list tests -----------------------------------------
978
979    #[test]
980    fn merge_candidates_no_neighbours() {
981        // Empty motion field — all candidates should be zero-MV fill.
982        let field = vec![HevcMvField::unavailable(); 16 * 16];
983        let cands = build_merge_candidates(&field, 16, 0, 0, 8, 8);
984        assert_eq!(cands.len(), 5);
985        // All should have pred_flag[0] true (zero-MV default fill).
986        for c in &cands {
987            assert!(c.pred_flag[0]);
988        }
989    }
990
991    #[test]
992    fn merge_candidates_with_left_neighbour() {
993        let pw = 16usize; // pic_width_in_min_pu
994        let mut field = vec![HevcMvField::unavailable(); pw * pw];
995        // Set left neighbour at min-PU (0, 1) — that's A1 for a block at (4, 0).
996        let left = HevcMvField {
997            mv: [HevcMv { x: 8, y: 4 }, HevcMv::default()],
998            ref_idx: [0, -1],
999            pred_flag: [true, false],
1000        };
1001        // Block at pixel (4, 0), size 4x4 -> pu_x=1, pu_y=0, pu_w=1, pu_h=1.
1002        // A1 is at (pu_x-1, pu_y+pu_h-1) = (0, 0).
1003        field[0] = left;
1004        let cands = build_merge_candidates(&field, pw, 4, 0, 4, 4);
1005        assert_eq!(cands.len(), 5);
1006        assert_eq!(cands[0].mv[0].x, 8);
1007        assert_eq!(cands[0].mv[0].y, 4);
1008    }
1009
1010    // -- AMVP candidate tests -----------------------------------------------
1011
1012    #[test]
1013    fn amvp_no_neighbours() {
1014        let field = vec![HevcMvField::unavailable(); 16 * 16];
1015        let cands = build_amvp_candidates(&field, 16, 4, 4, 0, 0);
1016        // No neighbours -> zero MV.
1017        assert_eq!(cands[0], HevcMv::default());
1018        assert_eq!(cands[1], HevcMv::default());
1019    }
1020
1021    #[test]
1022    fn amvp_with_left_neighbour() {
1023        let pw = 16usize;
1024        let mut field = vec![HevcMvField::unavailable(); pw * pw];
1025        let left = HevcMvField {
1026            mv: [HevcMv { x: 12, y: -8 }, HevcMv::default()],
1027            ref_idx: [0, -1],
1028            pred_flag: [true, false],
1029        };
1030        // Block at pixel (4, 4) -> pu (1, 1). Left A0 at (0, 1).
1031        field[pw] = left;
1032        let cands = build_amvp_candidates(&field, pw, 4, 4, 0, 0);
1033        assert_eq!(cands[0].x, 12);
1034        assert_eq!(cands[0].y, -8);
1035    }
1036
1037    // -- MVD parsing tests --------------------------------------------------
1038
1039    #[test]
1040    fn parse_mvd_zero() {
1041        // All-zero stream: abs_x_gt0 = false, abs_y_gt0 = false -> (0, 0).
1042        let data = [0x00u8; 16];
1043        let mut cabac = CabacDecoder::new(&data);
1044        let mv = parse_mvd(&mut cabac);
1045        assert_eq!(mv.x, 0);
1046        assert_eq!(mv.y, 0);
1047    }
1048
1049    #[test]
1050    fn parse_mvd_deterministic() {
1051        // Non-trivial stream should produce deterministic MVD.
1052        let data = [0xFFu8; 32];
1053        let mut cabac = CabacDecoder::new(&data);
1054        let mv1 = parse_mvd(&mut cabac);
1055
1056        let mut cabac2 = CabacDecoder::new(&data);
1057        let mv2 = parse_mvd(&mut cabac2);
1058        assert_eq!(mv1, mv2);
1059    }
1060
1061    // -- parse_merge_idx tests ----------------------------------------------
1062
1063    #[test]
1064    fn parse_merge_idx_single_candidate() {
1065        let data = [0xFFu8; 8];
1066        let mut cabac = CabacDecoder::new(&data);
1067        let idx = parse_merge_idx(&mut cabac, 1);
1068        assert_eq!(idx, 0);
1069    }
1070
1071    #[test]
1072    fn parse_merge_idx_zero_stream() {
1073        let data = [0x00u8; 16];
1074        let mut cabac = CabacDecoder::new(&data);
1075        let idx = parse_merge_idx(&mut cabac, 5);
1076        assert_eq!(idx, 0);
1077    }
1078
1079    // -- HevcMvField tests --------------------------------------------------
1080
1081    #[test]
1082    fn mvfield_unavailable() {
1083        let f = HevcMvField::unavailable();
1084        assert!(!f.is_available());
1085        assert_eq!(f.ref_idx[0], -1);
1086        assert_eq!(f.ref_idx[1], -1);
1087    }
1088
1089    #[test]
1090    fn mvfield_available() {
1091        let f = HevcMvField {
1092            mv: [HevcMv { x: 1, y: 2 }, HevcMv::default()],
1093            ref_idx: [0, -1],
1094            pred_flag: [true, false],
1095        };
1096        assert!(f.is_available());
1097    }
1098
1099    // -- Full inter CU decode on synthetic data -----------------------------
1100
1101    #[test]
1102    fn parse_inter_prediction_p_slice_synthetic() {
1103        let data = [0x55u8; 128];
1104        let mut state = HevcSliceCabacState::new(&data, 26);
1105        let sps = test_sps();
1106        let field = vec![HevcMvField::unavailable(); 16 * 16];
1107        let mvf = parse_inter_prediction(&mut state, &sps, HevcSliceType::P, &field, 16, 0, 0, 8);
1108        // Should produce some prediction field.
1109        assert!(mvf.pred_flag[0] || mvf.pred_flag[1]);
1110    }
1111
1112    #[test]
1113    fn parse_inter_prediction_b_slice_synthetic() {
1114        let data = [0xAAu8; 128];
1115        let mut state = HevcSliceCabacState::new(&data, 26);
1116        let sps = test_sps();
1117        let field = vec![HevcMvField::unavailable(); 16 * 16];
1118        let mvf = parse_inter_prediction(&mut state, &sps, HevcSliceType::B, &field, 16, 0, 0, 8);
1119        assert!(mvf.pred_flag[0] || mvf.pred_flag[1]);
1120    }
1121
1122    #[test]
1123    fn parse_inter_prediction_merge_mode() {
1124        // Build a stream where merge_flag = true (first bypass bin = 1 in
1125        // the CABAC stream). We use 0xFF which makes bypass bins decode as 1.
1126        let data = [0xFFu8; 128];
1127        let mut state = HevcSliceCabacState::new(&data, 26);
1128        let sps = test_sps();
1129        let field = vec![HevcMvField::unavailable(); 16 * 16];
1130        let mvf = parse_inter_prediction(&mut state, &sps, HevcSliceType::P, &field, 16, 4, 4, 8);
1131        // Merge mode with no real neighbours falls back to zero-MV fill.
1132        assert!(mvf.pred_flag[0]);
1133    }
1134
1135    // -- Luma filter coefficient sanity -------------------------------------
1136
1137    #[test]
1138    fn luma_filter_coefficients_sum() {
1139        // Each set of filter coefficients should sum to 64.
1140        for (i, row) in HEVC_LUMA_FILTER.iter().enumerate() {
1141            let sum: i16 = row.iter().sum();
1142            assert_eq!(sum, 64, "filter row {i} sums to {sum}, expected 64");
1143        }
1144    }
1145
1146    // -- Helper for tests ---------------------------------------------------
1147
1148    fn test_sps() -> HevcSps {
1149        HevcSps {
1150            sps_id: 0,
1151            vps_id: 0,
1152            max_sub_layers: 1,
1153            chroma_format_idc: 1,
1154            pic_width: 64,
1155            pic_height: 64,
1156            bit_depth_luma: 8,
1157            bit_depth_chroma: 8,
1158            log2_max_pic_order_cnt: 4,
1159            log2_min_cb_size: 3,
1160            log2_diff_max_min_cb_size: 3,
1161            log2_min_transform_size: 2,
1162            log2_diff_max_min_transform_size: 3,
1163            max_transform_hierarchy_depth_inter: 1,
1164            max_transform_hierarchy_depth_intra: 1,
1165            sample_adaptive_offset_enabled: false,
1166            pcm_enabled: false,
1167            num_short_term_ref_pic_sets: 0,
1168            long_term_ref_pics_present: false,
1169            sps_temporal_mvp_enabled: false,
1170            strong_intra_smoothing_enabled: false,
1171        }
1172    }
1173}