Skip to main content

gamut_webp/vp8/
prediction.rs

//! VP8 intra prediction (RFC 6386 §12) and key-frame intra-mode coding (RFC 6386 §11). Key-frame
//! (intra-only) coding only.
//!
//! Key frames code, per macroblock, a luma 16×16 mode and a chroma 8×8 mode, each tree-coded over the
//! boolean coder with fixed key-frame probabilities (§11.2, §11.4). [`predict_block`] handles the four
//! whole-block modes (DC/V/H/TM); [`subblock_predict`] handles the ten 4×4 `B_PRED` submodes, whose
//! modes are themselves context-coded against the above/left neighbors' submodes via
//! [`KF_BMODE_PROB`] (§11.3). The [`LumaMode`] / [`SubBlockMode`] / [`ChromaMode`] enums name the full
//! mode space. The `*_PRED` and `B_*_PRED` constants hold the same numeric values. The coding trees
//! store each mode `v` as the negated leaf `-v` (so mode 0 is the leaf `0`).
9
10use gamut_color::clip_pixel8;
11
12use super::bool_coder::{Prob, Tree};
13
/// A macroblock's luma (16×16) intra prediction mode (RFC 6386 §11.2, §12.3).
///
/// The discriminants are spelled out so that `mode as usize` equals the matching whole-block
/// constant (`DC_PRED` = 0 … `B_PRED` = 4).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LumaMode {
    /// DC: fill with the average of the available top/left edges.
    Dc = 0,
    /// Vertical: repeat the row above down the block.
    Vertical = 1,
    /// Horizontal: repeat the left column across the block.
    Horizontal = 2,
    /// TrueMotion: top row plus left column minus the top-left corner.
    TrueMotion = 3,
    /// Per-4×4-subblock prediction: each of the 16 subblocks picks its own [`SubBlockMode`].
    BPred = 4,
}
28
/// Prediction mode of one 4×4 luma subblock, used when the macroblock is coded as
/// [`LumaMode::BPred`] (RFC 6386 §11.2, §12.3): ten directional and averaging modes.
///
/// The discriminants are written out so that `mode as usize` equals the matching `B_*_PRED`
/// constant (`B_DC_PRED` = 0 … `B_HU_PRED` = 9).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SubBlockMode {
    /// `B_DC_PRED`: average of the eight neighbors.
    Dc = 0,
    /// `B_TM_PRED`: TrueMotion.
    TrueMotion = 1,
    /// `B_VE_PRED`: smoothed vertical.
    Vertical = 2,
    /// `B_HE_PRED`: smoothed horizontal.
    Horizontal = 3,
    /// `B_LD_PRED`: diagonal running down and to the left.
    LeftDown = 4,
    /// `B_RD_PRED`: diagonal running down and to the right.
    RightDown = 5,
    /// `B_VR_PRED`: steep diagonal, vertical-right.
    VerticalRight = 6,
    /// `B_VL_PRED`: steep diagonal, vertical-left.
    VerticalLeft = 7,
    /// `B_HD_PRED`: shallow diagonal, horizontal-down.
    HorizontalDown = 8,
    /// `B_HU_PRED`: shallow diagonal, horizontal-up.
    HorizontalUp = 9,
}
54
/// A macroblock's chroma (8×8) intra prediction mode (RFC 6386 §12.2), shared by both chroma planes.
///
/// It is the luma whole-block set without `B_PRED`. The explicit discriminants make `mode as usize`
/// equal the matching constant (`DC_PRED` = 0 … `TM_PRED` = 3).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ChromaMode {
    /// DC: average prediction.
    Dc = 0,
    /// Vertical: repeat the row above.
    Vertical = 1,
    /// Horizontal: repeat the left column.
    Horizontal = 2,
    /// TrueMotion prediction.
    TrueMotion = 3,
}
67
/// Whole-block mode index of `DC_PRED`, equal to [`LumaMode::Dc`] / [`ChromaMode::Dc`] as `usize`.
pub const DC_PRED: usize = 0;
/// Whole-block mode index of `V_PRED`, vertical prediction from the row above.
pub const V_PRED: usize = 1;
/// Whole-block mode index of `H_PRED`, horizontal prediction from the left column.
pub const H_PRED: usize = 2;
/// Whole-block mode index of `TM_PRED`, TrueMotion prediction.
pub const TM_PRED: usize = 3;
/// Luma-only mode index of `B_PRED`, per-4×4-subblock prediction.
pub const B_PRED: usize = 4;
78
79/// Key-frame luma 16×16 mode tree (RFC 6386 §11.2 / §8.2 `kf_ymode_tree`). Leaf values are the mode
80/// indices above (`-4` = `B_PRED`, `0` = `DC_PRED`, `-1` = `V_PRED`, …).
81#[rustfmt::skip]
82pub const KF_YMODE_TREE: &Tree = &[-4, 2, 4, 6, 0, -1, -2, -3];
83
84/// Key-frame luma 16×16 mode probabilities (RFC 6386 §11.2 `kf_ymode_prob`).
85pub const KF_YMODE_PROB: [Prob; 4] = [145, 156, 163, 128];
86
87/// Chroma 8×8 mode tree (RFC 6386 §11.4 / §8.2 `uv_mode_tree`).
88#[rustfmt::skip]
89pub const KF_UV_MODE_TREE: &Tree = &[0, 2, -1, 4, -2, -3];
90
91/// Key-frame chroma 8×8 mode probabilities (RFC 6386 §11.4 `kf_uv_mode_prob`).
92pub const KF_UV_MODE_PROB: [Prob; 3] = [142, 114, 183];
93
/// Computes the single `DC_PRED` value filling an `n`×`n` block (RFC 6386 §12.2/§12.3): the rounded
/// average of the available reconstructed neighbors.
///
/// `above` is the row immediately above the block and `left` the column immediately to its left,
/// each `None` when off-frame. Only the first `n` pixels of each edge are averaged, so a caller may
/// pass a longer row (e.g. one that continues into the above-right pixels) without skewing the
/// result. With neither neighbor (the top-left block) the value is the constant 128. With one, only
/// that neighbor is averaged (the spec's edge exception, *not* the 127/129 out-of-bounds fills).
///
/// `n` must be a power of two (16 for luma, 8 for chroma; 1 also works) so the average is an exact
/// rounded shift. Debug builds check this whenever an edge is present. Each supplied edge is
/// expected to hold at least `n` pixels; a shorter one is summed as given and biases the average
/// low.
#[must_use]
pub fn dc_predict(n: usize, above: Option<&[u8]>, left: Option<&[u8]>) -> u8 {
    // Sum of the first `n` pixels of one edge; anything past `n` is not part of this block's border.
    let sum_of = |pixels: &[u8]| pixels.iter().take(n).map(|&p| u32::from(p)).sum::<u32>();
    let (sum, count) = match (above, left) {
        (None, None) => return 128,
        // 2n summands -> shift by log2(2n) (5 for luma, 4 for chroma).
        (Some(a), Some(l)) => (sum_of(a) + sum_of(l), 2 * n),
        // n summands -> shift by log2(n) (4 for luma, 3 for chroma).
        (Some(only), None) | (None, Some(only)) => (sum_of(only), n),
    };
    debug_assert!(count.is_power_of_two(), "dc_predict: n = {n} is not a power of two");
    let shf = count.trailing_zeros();
    // `(1 << shf) >> 1` is the rounding half-unit. Unlike `1 << (shf - 1)` it does not underflow
    // when shf == 0 (n == 1 with a single edge), where no rounding is needed.
    ((sum + ((1u32 << shf) >> 1)) >> shf) as u8
}
112
/// Copies an optional neighbor edge into a fixed 16-pixel buffer, padding with `fill`.
///
/// When `pixels` is `None` (the edge is off-frame) all 16 entries are `fill`. Otherwise the edge is
/// copied from the start and any entries past `pixels.len()` stay `fill`. Pixels beyond the 16th
/// are ignored rather than causing a panic, so callers may pass a longer row (e.g. one that
/// includes above-right pixels).
fn edge(pixels: Option<&[u8]>, fill: u8) -> [u8; 16] {
    let mut e = [fill; 16];
    if let Some(p) = pixels {
        // Clamp to the buffer: `e[..p.len()]` would panic for rows longer than 16.
        let len = p.len().min(e.len());
        e[..len].copy_from_slice(&p[..len]);
    }
    e
}
121
122/// Fills `out[..n*n]` (row-major) with the whole-block prediction for `mode` — `DC_PRED`, `V_PRED`,
123/// `H_PRED`, or `TM_PRED` (RFC 6386 §12.2/§12.3). `above`/`left` are the `n` reconstructed neighbor
124/// pixels (`None` off-frame: V/H/TM substitute the 127/129 out-of-bounds values, while DC uses its
125/// averaging edge exception); `corner` is the above-left pixel TrueMotion propagates from.
126pub fn predict_block(
127    mode: usize,
128    n: usize,
129    above: Option<&[u8]>,
130    left: Option<&[u8]>,
131    corner: u8,
132    out: &mut [u8],
133) {
134    match mode {
135        V_PRED => {
136            let a = edge(above, 127);
137            for r in 0..n {
138                out[r * n..r * n + n].copy_from_slice(&a[..n]);
139            }
140        }
141        H_PRED => {
142            let l = edge(left, 129);
143            for r in 0..n {
144                out[r * n..r * n + n].fill(l[r]);
145            }
146        }
147        TM_PRED => {
148            let a = edge(above, 127);
149            let l = edge(left, 129);
150            let p = i32::from(corner);
151            for r in 0..n {
152                for c in 0..n {
153                    out[r * n + c] = clip_pixel8(i32::from(l[r]) + i32::from(a[c]) - p);
154                }
155            }
156        }
157        _ => out[..n * n].fill(dc_predict(n, above, left)),
158    }
159}
160
/// Subblock mode `B_DC_PRED` (RFC 6386 §11.2 `intra_bmode`). These `B_*_PRED` values are the modes
/// named by the leaves of [`BMODE_TREE`].
pub const B_DC_PRED: usize = 0;
/// Subblock mode `B_TM_PRED`: TrueMotion.
pub const B_TM_PRED: usize = 1;
/// Subblock mode `B_VE_PRED`: smoothed vertical.
pub const B_VE_PRED: usize = 2;
/// Subblock mode `B_HE_PRED`: smoothed horizontal.
pub const B_HE_PRED: usize = 3;
/// Subblock mode `B_LD_PRED`: left-down diagonal.
pub const B_LD_PRED: usize = 4;
/// Subblock mode `B_RD_PRED`: right-down diagonal.
pub const B_RD_PRED: usize = 5;
/// Subblock mode `B_VR_PRED`: vertical-right diagonal.
pub const B_VR_PRED: usize = 6;
/// Subblock mode `B_VL_PRED`: vertical-left diagonal.
pub const B_VL_PRED: usize = 7;
/// Subblock mode `B_HD_PRED`: horizontal-down diagonal.
pub const B_HD_PRED: usize = 8;
/// Subblock mode `B_HU_PRED`: horizontal-up diagonal, the last mode.
pub const B_HU_PRED: usize = 9;
/// How many 4×4 luma subblock modes exist (one past the last mode index).
pub const NUM_BMODES: usize = B_HU_PRED + 1;
183
184/// Subblock-mode coding tree (RFC 6386 §11.2 `bmode_tree`). Leaf `-v` (and `0`) is mode `v`.
185#[rustfmt::skip]
186pub const BMODE_TREE: &Tree = &[
187     0,  2,   -1,  4,   -2,  6,    8, 12,
188    -3, 10,   -5, -6,   -4, 14,   -7, 16,
189    -8, -9,
190];
191
192/// Maps a whole-block luma mode to the subblock mode a non-`B_PRED` macroblock contributes to its
193/// neighbors' subblock-mode context (RFC 6386 §11.3 caveat 4).
194#[must_use]
195pub fn bmode_for_luma(luma_mode: usize) -> usize {
196    match luma_mode {
197        V_PRED => B_VE_PRED,
198        H_PRED => B_HE_PRED,
199        TM_PRED => B_TM_PRED,
200        _ => B_DC_PRED,
201    }
202}
203
/// RFC 6386 §12.3 `avg3`: the `1-2-1` weighted average of three neighbors, centered on `y`, with
/// rounding (`(x + 2y + z + 2) >> 2`).
fn avg3(x: i32, y: i32, z: i32) -> u8 {
    let weighted = x + 2 * y + z;
    ((weighted + 2) >> 2) as u8
}
208
/// RFC 6386 §12.3 `avg2`: the rounded mean of two neighbors (`(x + y + 1) >> 1`).
fn avg2(x: i32, y: i32) -> u8 {
    let sum = x + y;
    ((sum + 1) >> 1) as u8
}
213
214/// Predicts one 4×4 luma subblock under `B_PRED` submode `mode` (RFC 6386 §12.3). `above` holds the
215/// eight pixels `A[0..8]` of the row above (four directly above, four above-right), `left` the four
216/// `L[0..4]` to the left, and `corner` the above-left pixel `A[-1] == L[-1]`. Returns the row-major
217/// 4×4 prediction; the diagonal modes synthesize predictors from the `E` edge `[L3 L2 L1 L0 P A0..A3]`.
218#[must_use]
219pub fn subblock_predict(mode: usize, above: &[u8; 8], left: &[u8; 4], corner: u8) -> [u8; 16] {
220    // ax[0]=P, ax[1..9]=A[0..8]; lx[0]=P, lx[1..5]=L[0..4]; e = bottom-left..top-right edge.
221    let mut ax = [i32::from(corner); 9];
222    for i in 0..8 {
223        ax[i + 1] = i32::from(above[i]);
224    }
225    let mut lx = [i32::from(corner); 5];
226    for i in 0..4 {
227        lx[i + 1] = i32::from(left[i]);
228    }
229    let e = [
230        lx[4],
231        lx[3],
232        lx[2],
233        lx[1],
234        i32::from(corner),
235        ax[1],
236        ax[2],
237        ax[3],
238        ax[4],
239    ];
240    // B_DC_PRED (and any out-of-range mode): the no-edge-exception average of the eight neighbors.
241    if mode == B_DC_PRED || mode > B_HU_PRED {
242        let v = ((ax[1] + ax[2] + ax[3] + ax[4] + lx[1] + lx[2] + lx[3] + lx[4] + 4) >> 3) as u8;
243        return [v; 16];
244    }
245    let a3 = |j: usize| avg3(ax[j], ax[j + 1], ax[j + 2]); // avg3p(A + j)
246    let a2 = |j: usize| avg2(ax[j + 1], ax[j + 2]); // avg2p(A + j)
247    let l3 = |r: usize| avg3(lx[r], lx[r + 1], lx[r + 2]); // avg3p(L + r)
248    let e3 = |i: usize| avg3(e[i - 1], e[i], e[i + 1]); // avg3p(E + i)
249    let e2 = |i: usize| avg2(e[i], e[i + 1]); // avg2p(E + i)
250
251    let mut b = [0u8; 16];
252    let mut set = |r: usize, c: usize, v: u8| b[r * 4 + c] = v;
253    match mode {
254        B_TM_PRED => {
255            for r in 0..4 {
256                for c in 0..4 {
257                    set(r, c, clip_pixel8(lx[r + 1] + ax[c + 1] - i32::from(corner)));
258                }
259            }
260        }
261        B_VE_PRED => {
262            for c in 0..4 {
263                let v = a3(c);
264                for r in 0..4 {
265                    set(r, c, v);
266                }
267            }
268        }
269        B_HE_PRED => {
270            let rows = [l3(0), l3(1), l3(2), avg3(lx[3], lx[4], lx[4])];
271            for (r, &row) in rows.iter().enumerate() {
272                for c in 0..4 {
273                    set(r, c, row);
274                }
275            }
276        }
277        B_LD_PRED => {
278            set(0, 0, a3(1));
279            set(0, 1, a3(2));
280            set(1, 0, a3(2));
281            set(0, 2, a3(3));
282            set(1, 1, a3(3));
283            set(2, 0, a3(3));
284            set(0, 3, a3(4));
285            set(1, 2, a3(4));
286            set(2, 1, a3(4));
287            set(3, 0, a3(4));
288            set(1, 3, a3(5));
289            set(2, 2, a3(5));
290            set(3, 1, a3(5));
291            set(2, 3, a3(6));
292            set(3, 2, a3(6));
293            set(3, 3, avg3(ax[7], ax[8], ax[8]));
294        }
295        B_RD_PRED => {
296            set(3, 0, e3(1));
297            set(3, 1, e3(2));
298            set(2, 0, e3(2));
299            set(3, 2, e3(3));
300            set(2, 1, e3(3));
301            set(1, 0, e3(3));
302            set(3, 3, e3(4));
303            set(2, 2, e3(4));
304            set(1, 1, e3(4));
305            set(0, 0, e3(4));
306            set(2, 3, e3(5));
307            set(1, 2, e3(5));
308            set(0, 1, e3(5));
309            set(1, 3, e3(6));
310            set(0, 2, e3(6));
311            set(0, 3, e3(7));
312        }
313        B_VR_PRED => {
314            set(3, 0, e3(2));
315            set(2, 0, e3(3));
316            set(3, 1, e3(4));
317            set(1, 0, e3(4));
318            set(2, 1, e2(4));
319            set(0, 0, e2(4));
320            set(3, 2, e3(5));
321            set(1, 1, e3(5));
322            set(2, 2, e2(5));
323            set(0, 1, e2(5));
324            set(3, 3, e3(6));
325            set(1, 2, e3(6));
326            set(2, 3, e2(6));
327            set(0, 2, e2(6));
328            set(1, 3, e3(7));
329            set(0, 3, e2(7));
330        }
331        B_VL_PRED => {
332            set(0, 0, a2(0));
333            set(1, 0, a3(1));
334            set(2, 0, a2(1));
335            set(0, 1, a2(1));
336            set(1, 1, a3(2));
337            set(3, 0, a3(2));
338            set(2, 1, a2(2));
339            set(0, 2, a2(2));
340            set(3, 1, a3(3));
341            set(1, 2, a3(3));
342            set(2, 2, a2(3));
343            set(0, 3, a2(3));
344            set(3, 2, a3(4));
345            set(1, 3, a3(4));
346            set(2, 3, a3(5));
347            set(3, 3, a3(6));
348        }
349        B_HD_PRED => {
350            set(3, 0, e2(0));
351            set(3, 1, e3(1));
352            set(2, 0, e2(1));
353            set(3, 2, e2(1));
354            set(2, 1, e3(2));
355            set(3, 3, e3(2));
356            set(2, 2, e2(2));
357            set(1, 0, e2(2));
358            set(2, 3, e3(3));
359            set(1, 1, e3(3));
360            set(1, 2, e2(3));
361            set(0, 0, e2(3));
362            set(1, 3, e3(4));
363            set(0, 1, e3(4));
364            set(0, 2, e3(5));
365            set(0, 3, e3(6));
366        }
367        B_HU_PRED => {
368            set(0, 0, avg2(lx[1], lx[2]));
369            set(0, 1, l3(1));
370            set(0, 2, avg2(lx[2], lx[3]));
371            set(1, 0, avg2(lx[2], lx[3]));
372            set(0, 3, l3(2));
373            set(1, 1, l3(2));
374            set(1, 2, avg2(lx[3], lx[4]));
375            set(2, 0, avg2(lx[3], lx[4]));
376            set(1, 3, avg3(lx[3], lx[4], lx[4]));
377            set(2, 1, avg3(lx[3], lx[4], lx[4]));
378            let l_last = lx[4] as u8;
379            for (r, c) in [(2, 2), (2, 3), (3, 0), (3, 1), (3, 2), (3, 3)] {
380                set(r, c, l_last);
381            }
382        }
383        _ => {
384            // B_DC_PRED: average of the four above + four left, no edge exception.
385            let v =
386                ((ax[1] + ax[2] + ax[3] + ax[4] + lx[1] + lx[2] + lx[3] + lx[4] + 4) >> 3) as u8;
387            b = [v; 16];
388        }
389    }
390    b
391}
392
393/// Key-frame subblock-mode probabilities `[above_mode][left_mode][tree_node]` (RFC 6386 §11.3/§11.5).
394#[rustfmt::skip]
395pub const KF_BMODE_PROB: [[[Prob; 9]; NUM_BMODES]; NUM_BMODES] = [
396    [
397        [231, 120,  48,  89, 115, 113, 120, 152, 112],
398        [152, 179,  64, 126, 170, 118,  46,  70,  95],
399        [175,  69, 143,  80,  85,  82,  72, 155, 103],
400        [ 56,  58,  10, 171, 218, 189,  17,  13, 152],
401        [144,  71,  10,  38, 171, 213, 144,  34,  26],
402        [114,  26,  17, 163,  44, 195,  21,  10, 173],
403        [121,  24,  80, 195,  26,  62,  44,  64,  85],
404        [170,  46,  55,  19, 136, 160,  33, 206,  71],
405        [ 63,  20,   8, 114, 114, 208,  12,   9, 226],
406        [ 81,  40,  11,  96, 182,  84,  29,  16,  36],
407    ],
408    [
409        [134, 183,  89, 137,  98, 101, 106, 165, 148],
410        [ 72, 187, 100, 130, 157, 111,  32,  75,  80],
411        [ 66, 102, 167,  99,  74,  62,  40, 234, 128],
412        [ 41,  53,   9, 178, 241, 141,  26,   8, 107],
413        [104,  79,  12,  27, 217, 255,  87,  17,   7],
414        [ 74,  43,  26, 146,  73, 166,  49,  23, 157],
415        [ 65,  38, 105, 160,  51,  52,  31, 115, 128],
416        [ 87,  68,  71,  44, 114,  51,  15, 186,  23],
417        [ 47,  41,  14, 110, 182, 183,  21,  17, 194],
418        [ 66,  45,  25, 102, 197, 189,  23,  18,  22],
419    ],
420    [
421        [ 88,  88, 147, 150,  42,  46,  45, 196, 205],
422        [ 43,  97, 183, 117,  85,  38,  35, 179,  61],
423        [ 39,  53, 200,  87,  26,  21,  43, 232, 171],
424        [ 56,  34,  51, 104, 114, 102,  29,  93,  77],
425        [107,  54,  32,  26,  51,   1,  81,  43,  31],
426        [ 39,  28,  85, 171,  58, 165,  90,  98,  64],
427        [ 34,  22, 116, 206,  23,  34,  43, 166,  73],
428        [ 68,  25, 106,  22,  64, 171,  36, 225, 114],
429        [ 34,  19,  21, 102, 132, 188,  16,  76, 124],
430        [ 62,  18,  78,  95,  85,  57,  50,  48,  51],
431    ],
432    [
433        [193, 101,  35, 159, 215, 111,  89,  46, 111],
434        [ 60, 148,  31, 172, 219, 228,  21,  18, 111],
435        [112, 113,  77,  85, 179, 255,  38, 120, 114],
436        [ 40,  42,   1, 196, 245, 209,  10,  25, 109],
437        [100,  80,   8,  43, 154,   1,  51,  26,  71],
438        [ 88,  43,  29, 140, 166, 213,  37,  43, 154],
439        [ 61,  63,  30, 155,  67,  45,  68,   1, 209],
440        [142,  78,  78,  16, 255, 128,  34, 197, 171],
441        [ 41,  40,   5, 102, 211, 183,   4,   1, 221],
442        [ 51,  50,  17, 168, 209, 192,  23,  25,  82],
443    ],
444    [
445        [125,  98,  42,  88, 104,  85, 117, 175,  82],
446        [ 95,  84,  53,  89, 128, 100, 113, 101,  45],
447        [ 75,  79, 123,  47,  51, 128,  81, 171,   1],
448        [ 57,  17,   5,  71, 102,  57,  53,  41,  49],
449        [115,  21,   2,  10, 102, 255, 166,  23,   6],
450        [ 38,  33,  13, 121,  57,  73,  26,   1,  85],
451        [ 41,  10,  67, 138,  77, 110,  90,  47, 114],
452        [101,  29,  16,  10,  85, 128, 101, 196,  26],
453        [ 57,  18,  10, 102, 102, 213,  34,  20,  43],
454        [117,  20,  15,  36, 163, 128,  68,   1,  26],
455    ],
456    [
457        [138,  31,  36, 171,  27, 166,  38,  44, 229],
458        [ 67,  87,  58, 169,  82, 115,  26,  59, 179],
459        [ 63,  59,  90, 180,  59, 166,  93,  73, 154],
460        [ 40,  40,  21, 116, 143, 209,  34,  39, 175],
461        [ 57,  46,  22,  24, 128,   1,  54,  17,  37],
462        [ 47,  15,  16, 183,  34, 223,  49,  45, 183],
463        [ 46,  17,  33, 183,   6,  98,  15,  32, 183],
464        [ 65,  32,  73, 115,  28, 128,  23, 128, 205],
465        [ 40,   3,   9, 115,  51, 192,  18,   6, 223],
466        [ 87,  37,   9, 115,  59,  77,  64,  21,  47],
467    ],
468    [
469        [104,  55,  44, 218,   9,  54,  53, 130, 226],
470        [ 64,  90,  70, 205,  40,  41,  23,  26,  57],
471        [ 54,  57, 112, 184,   5,  41,  38, 166, 213],
472        [ 30,  34,  26, 133, 152, 116,  10,  32, 134],
473        [ 75,  32,  12,  51, 192, 255, 160,  43,  51],
474        [ 39,  19,  53, 221,  26, 114,  32,  73, 255],
475        [ 31,   9,  65, 234,   2,  15,   1, 118,  73],
476        [ 88,  31,  35,  67, 102,  85,  55, 186,  85],
477        [ 56,  21,  23, 111,  59, 205,  45,  37, 192],
478        [ 55,  38,  70, 124,  73, 102,   1,  34,  98],
479    ],
480    [
481        [102,  61,  71,  37,  34,  53,  31, 243, 192],
482        [ 69,  60,  71,  38,  73, 119,  28, 222,  37],
483        [ 68,  45, 128,  34,   1,  47,  11, 245, 171],
484        [ 62,  17,  19,  70, 146,  85,  55,  62,  70],
485        [ 75,  15,   9,   9,  64, 255, 184, 119,  16],
486        [ 37,  43,  37, 154, 100, 163,  85, 160,   1],
487        [ 63,   9,  92, 136,  28,  64,  32, 201,  85],
488        [ 86,   6,  28,   5,  64, 255,  25, 248,   1],
489        [ 56,   8,  17, 132, 137, 255,  55, 116, 128],
490        [ 58,  15,  20,  82, 135,  57,  26, 121,  40],
491    ],
492    [
493        [164,  50,  31, 137, 154, 133,  25,  35, 218],
494        [ 51, 103,  44, 131, 131, 123,  31,   6, 158],
495        [ 86,  40,  64, 135, 148, 224,  45, 183, 128],
496        [ 22,  26,  17, 131, 240, 154,  14,   1, 209],
497        [ 83,  12,  13,  54, 192, 255,  68,  47,  28],
498        [ 45,  16,  21,  91,  64, 222,   7,   1, 197],
499        [ 56,  21,  39, 155,  60, 138,  23, 102, 213],
500        [ 85,  26,  85,  85, 128, 128,  32, 146, 171],
501        [ 18,  11,   7,  63, 144, 171,   4,   4, 246],
502        [ 35,  27,  10, 146, 174, 171,  12,  26, 128],
503    ],
504    [
505        [190,  80,  35,  99, 180,  80, 126,  54,  45],
506        [ 85, 126,  47,  87, 176,  51,  41,  20,  32],
507        [101,  75, 128, 139, 118, 146, 116, 128,  85],
508        [ 56,  41,  15, 176, 236,  85,  37,   9,  62],
509        [146,  36,  19,  30, 171, 255,  97,  27,  20],
510        [ 71,  30,  17, 119, 118, 255,  17,  18, 138],
511        [101,  38,  60, 138,  55,  70,  43,  26, 142],
512        [138,  45,  61,  62, 219,   1,  81, 188,  64],
513        [ 32,  41,  20, 117, 151, 142,  20,  21, 163],
514        [112,  19,  12,  61, 195, 128,  48,   4,  24],
515    ],
516];
517
#[cfg(test)]
mod tests {
    use super::*;

    /// The modes named by a coding tree's leaves (entries `-v`, including `0`), sorted ascending.
    fn leaf_modes(tree: &Tree) -> Vec<usize> {
        let mut modes: Vec<usize> = tree
            .iter()
            .filter(|&&v| v <= 0)
            .map(|&v| (-v) as usize)
            .collect();
        modes.sort_unstable();
        modes
    }

    /// Above row forming a linear ramp, so `avg3p(A + j) == A[j]` for `j` in `1..=6`.
    const RAMP_ROW: [u8; 8] = [0, 4, 8, 12, 16, 20, 24, 28];
    /// Above/left/corner inputs whose edge `E` is the ramp 10, 20, …, 90 (`avg3p(E + i) == E[i]`).
    const E_RAMP_ABOVE: [u8; 8] = [60, 70, 80, 90, 0, 0, 0, 0];
    const E_RAMP_LEFT: [u8; 4] = [40, 30, 20, 10];
    const E_RAMP_CORNER: u8 = 50;

    #[test]
    fn dc_top_left_is_128() {
        assert_eq!(dc_predict(16, None, None), 128);
        assert_eq!(dc_predict(8, None, None), 128);
    }

    #[test]
    fn dc_single_edge_averages_that_edge() {
        let above = [100u8; 16];
        assert_eq!(dc_predict(16, Some(&above), None), 100);
        let left = [10u8, 10, 10, 10, 20, 20, 20, 20];
        let expected = ((10 * 4 + 20 * 4 + 4) >> 3) as u8;
        assert_eq!(dc_predict(8, None, Some(&left)), expected);
    }

    #[test]
    fn dc_both_edges_average_all() {
        let a = [64u8; 16];
        let l = [64u8; 16];
        assert_eq!(dc_predict(16, Some(&a), Some(&l)), 64);
        let a2 = [200u8; 16];
        let l2 = [0u8; 16];
        assert_eq!(
            dc_predict(16, Some(&a2), Some(&l2)),
            ((200 * 16 + 16) >> 5) as u8
        );
    }

    #[test]
    fn mode_trees_are_well_formed() {
        assert_eq!(KF_YMODE_TREE.len(), 8);
        assert_eq!(KF_UV_MODE_TREE.len(), 6);
        assert_eq!(LumaMode::Dc as usize, DC_PRED);
        assert_eq!(ChromaMode::TrueMotion as usize, TM_PRED);
        // Every tree names each of its modes at exactly one leaf.
        assert_eq!(leaf_modes(KF_YMODE_TREE), (0..=B_PRED).collect::<Vec<_>>());
        assert_eq!(leaf_modes(KF_UV_MODE_TREE), (0..=TM_PRED).collect::<Vec<_>>());
        assert_eq!(leaf_modes(BMODE_TREE), (0..NUM_BMODES).collect::<Vec<_>>());
    }

    #[test]
    fn enum_discriminants_match_mode_constants() {
        assert_eq!(LumaMode::Dc as usize, DC_PRED);
        assert_eq!(LumaMode::Vertical as usize, V_PRED);
        assert_eq!(LumaMode::Horizontal as usize, H_PRED);
        assert_eq!(LumaMode::TrueMotion as usize, TM_PRED);
        assert_eq!(LumaMode::BPred as usize, B_PRED);
        assert_eq!(ChromaMode::Dc as usize, DC_PRED);
        assert_eq!(ChromaMode::Vertical as usize, V_PRED);
        assert_eq!(ChromaMode::Horizontal as usize, H_PRED);
        assert_eq!(ChromaMode::TrueMotion as usize, TM_PRED);
        let sub = [
            (SubBlockMode::Dc, B_DC_PRED),
            (SubBlockMode::TrueMotion, B_TM_PRED),
            (SubBlockMode::Vertical, B_VE_PRED),
            (SubBlockMode::Horizontal, B_HE_PRED),
            (SubBlockMode::LeftDown, B_LD_PRED),
            (SubBlockMode::RightDown, B_RD_PRED),
            (SubBlockMode::VerticalRight, B_VR_PRED),
            (SubBlockMode::VerticalLeft, B_VL_PRED),
            (SubBlockMode::HorizontalDown, B_HD_PRED),
            (SubBlockMode::HorizontalUp, B_HU_PRED),
        ];
        for (mode, index) in sub {
            assert_eq!(mode as usize, index, "{mode:?}");
        }
    }

    #[test]
    fn bmode_for_luma_maps_whole_block_modes() {
        assert_eq!(bmode_for_luma(DC_PRED), B_DC_PRED);
        assert_eq!(bmode_for_luma(V_PRED), B_VE_PRED);
        assert_eq!(bmode_for_luma(H_PRED), B_HE_PRED);
        assert_eq!(bmode_for_luma(TM_PRED), B_TM_PRED);
    }

    #[test]
    fn vertical_prediction_copies_the_above_row() {
        let above: Vec<u8> = (0..16).map(|i| (i * 10) as u8).collect();
        let mut out = [0u8; 256];
        predict_block(V_PRED, 16, Some(&above), Some(&[50u8; 16]), 200, &mut out);
        for r in 0..16 {
            assert_eq!(&out[r * 16..r * 16 + 16], &above[..]);
        }
    }

    #[test]
    fn horizontal_prediction_copies_the_left_column() {
        let left: Vec<u8> = (0..16).map(|i| (i * 10) as u8).collect();
        let mut out = [0u8; 256];
        predict_block(H_PRED, 16, Some(&[50u8; 16]), Some(&left), 200, &mut out);
        for r in 0..16 {
            assert!(out[r * 16..r * 16 + 16].iter().all(|&p| p == left[r]));
        }
    }

    #[test]
    fn truemotion_propagates_from_the_corner() {
        // X[r][c] = clip_pixel8(L[r] + A[c] - P).
        let above = [10u8, 20, 30, 40];
        let left = [100u8, 110, 120, 130];
        let p = 50i32;
        let mut out = [0u8; 16];
        predict_block(TM_PRED, 4, Some(&above), Some(&left), p as u8, &mut out);
        for r in 0..4 {
            for c in 0..4 {
                let expect = clip_pixel8(i32::from(left[r]) + i32::from(above[c]) - p);
                assert_eq!(out[r * 4 + c], expect, "TM at ({r},{c})");
            }
        }
    }

    #[test]
    fn off_frame_edges_use_127_and_129() {
        let mut out = [0u8; 256];
        predict_block(V_PRED, 16, None, Some(&[5u8; 16]), 0, &mut out);
        assert!(
            out.iter().all(|&p| p == 127),
            "vertical off the top row fills 127"
        );
        predict_block(H_PRED, 16, Some(&[5u8; 16]), None, 0, &mut out);
        assert!(
            out.iter().all(|&p| p == 129),
            "horizontal off the left column fills 129"
        );
    }

    #[test]
    fn subblock_dc_averages_the_eight_neighbors() {
        let a = [10u8, 20, 30, 40, 0, 0, 0, 0]; // above-right is unused by B_DC_PRED
        let l = [50u8, 60, 70, 80];
        let out = subblock_predict(B_DC_PRED, &a, &l, 99);
        let v = ((10 + 20 + 30 + 40 + 50 + 60 + 70 + 80 + 4) >> 3) as u8;
        assert!(out.iter().all(|&p| p == v));
    }

    #[test]
    fn subblock_tm_matches_left_plus_above_minus_corner() {
        let a = [60u8, 70, 80, 90, 0, 0, 0, 0];
        let l = [100u8, 110, 120, 130];
        let p = 50u8;
        let out = subblock_predict(B_TM_PRED, &a, &l, p);
        for r in 0..4 {
            for c in 0..4 {
                let want = clip_pixel8(i32::from(l[r]) + i32::from(a[c]) - i32::from(p));
                assert_eq!(out[r * 4 + c], want, "TM at ({r},{c})");
            }
        }
    }

    #[test]
    fn subblock_ve_smooths_the_above_row() {
        // B[*][c] = avg3p(A + c) = (A[c-1] + 2*A[c] + A[c+1] + 2) >> 2, with A[-1] the corner.
        let a = [10u8, 20, 30, 40, 50, 0, 0, 0];
        let out = subblock_predict(B_VE_PRED, &a, &[0u8; 4], 5);
        let ext = [5i32, 10, 20, 30, 40, 50]; // [corner, A0..A4]
        for c in 0..4 {
            let want = ((ext[c] + 2 * ext[c + 1] + ext[c + 2] + 2) >> 2) as u8;
            for r in 0..4 {
                assert_eq!(out[r * 4 + c], want, "VE col {c}");
            }
        }
    }

    #[test]
    fn subblock_modes_preserve_a_flat_edge() {
        // With every neighbor equal to v, each averaging/diagonal mode (and TM) must predict v.
        for v in [0u8, 77, 255] {
            for mode in 0..NUM_BMODES {
                let out = subblock_predict(mode, &[v; 8], &[v; 4], v);
                assert_eq!(out, [v; 16], "mode {mode} with flat edge {v}");
            }
        }
    }

    #[test]
    fn subblock_he_smooths_the_left_column() {
        // Rows: avg3(P, L0, L1), avg3(L0, L1, L2), avg3(L1, L2, L3), avg3(L2, L3, L3).
        let out = subblock_predict(B_HE_PRED, &[0u8; 8], &[10, 20, 30, 40], 0);
        assert_eq!(
            out,
            [10, 10, 10, 10, 20, 20, 20, 20, 30, 30, 30, 30, 38, 38, 38, 38]
        );
    }

    #[test]
    fn subblock_ld_and_vl_follow_the_above_ramp() {
        let ld = subblock_predict(B_LD_PRED, &RAMP_ROW, &[0u8; 4], 0);
        // B[r][c] = A[r + c + 1] on the ramp; B[3][3] = avg3(A6, A7, A7).
        assert_eq!(
            ld,
            [4, 8, 12, 16, 8, 12, 16, 20, 12, 16, 20, 24, 16, 20, 24, 27]
        );
        let vl = subblock_predict(B_VL_PRED, &RAMP_ROW, &[0u8; 4], 0);
        assert_eq!(
            vl,
            [2, 6, 10, 14, 4, 8, 12, 16, 6, 10, 14, 20, 8, 12, 16, 24]
        );
    }

    #[test]
    fn subblock_rd_vr_hd_follow_the_edge_ramp() {
        let rd = subblock_predict(B_RD_PRED, &E_RAMP_ABOVE, &E_RAMP_LEFT, E_RAMP_CORNER);
        assert_eq!(
            rd,
            [50, 60, 70, 80, 40, 50, 60, 70, 30, 40, 50, 60, 20, 30, 40, 50]
        );
        let vr = subblock_predict(B_VR_PRED, &E_RAMP_ABOVE, &E_RAMP_LEFT, E_RAMP_CORNER);
        assert_eq!(
            vr,
            [55, 65, 75, 85, 50, 60, 70, 80, 40, 55, 65, 75, 30, 50, 60, 70]
        );
        let hd = subblock_predict(B_HD_PRED, &E_RAMP_ABOVE, &E_RAMP_LEFT, E_RAMP_CORNER);
        assert_eq!(
            hd,
            [45, 50, 60, 70, 35, 40, 45, 50, 25, 30, 35, 40, 15, 20, 25, 30]
        );
    }

    #[test]
    fn subblock_hu_walks_up_the_left_column() {
        let out = subblock_predict(B_HU_PRED, &[0u8; 8], &[10, 20, 30, 40], 0);
        assert_eq!(
            out,
            [15, 20, 25, 30, 25, 30, 35, 38, 35, 38, 40, 40, 40, 40, 40, 40]
        );
    }

    #[test]
    fn kf_bmode_prob_table_shape_and_corners() {
        assert_eq!(KF_BMODE_PROB.len(), NUM_BMODES);
        assert_eq!(KF_BMODE_PROB[0].len(), NUM_BMODES);
        assert_eq!(KF_BMODE_PROB[0][0].len(), 9);
        // Corner entries from RFC 6386 §11.5.
        assert_eq!(KF_BMODE_PROB[0][0][0], 231);
        assert_eq!(KF_BMODE_PROB[9][9], [112, 19, 12, 61, 195, 128, 48, 4, 24]);
        assert_eq!(BMODE_TREE.len(), 2 * (NUM_BMODES - 1));
    }
}