Skip to main content

oximedia_codec/av1/
entropy_tables.rs

1//! AV1 entropy coding default CDF tables.
2//!
3//! This module contains the default probability tables (CDFs) used for
4//! entropy coding in AV1. These tables are used as initial values and
5//! are updated during decoding based on observed symbol frequencies.
6//!
7//! # CDF Format
8//!
//! CDFs are stored as arrays of 16-bit unsigned integers. The last element
//! is reserved for the symbol count used in CDF update. Cumulative values use
//! 15-bit precision and run from 0 up to `CDF_PROB_TOP` (32768) inclusive.
12//!
13//! # Table Categories
14//!
15//! - **Partition CDFs** - Block partitioning decisions
16//! - **Intra Mode CDFs** - Intra prediction mode selection
17//! - **TX Size CDFs** - Transform size selection
18//! - **TX Type CDFs** - Transform type selection
19//! - **Coefficient CDFs** - Coefficient level and sign coding
20//! - **MV Component CDFs** - Motion vector component coding
21//!
22//! # Reference
23//!
24//! See AV1 Specification Section 9 for probability model initialization
25//! and update procedures.
26
27#![allow(dead_code)]
28#![allow(clippy::doc_markdown)]
29#![allow(clippy::needless_range_loop)]
30
31// =============================================================================
32// CDF Precision Constants
33// =============================================================================
34
/// CDF precision bits.
///
/// Cumulative probabilities are scaled so that certainty equals
/// `1 << CDF_PROB_BITS` (see [`CDF_PROB_TOP`]).
pub const CDF_PROB_BITS: u8 = 15;

/// Maximum CDF probability value.
///
/// Evaluates to 32768. [`is_valid_cdf`] requires the final probability entry
/// of a CDF to equal this value.
pub const CDF_PROB_TOP: u16 = 1 << CDF_PROB_BITS;

/// Initial symbol count for CDF adaptation.
///
/// Written into the trailing counter slot by [`create_uniform_cdf`] and
/// [`reset_cdf_uniform`].
pub const CDF_INIT_COUNT: u16 = 0;

/// Maximum symbol count for CDF adaptation rate.
///
/// [`update_cdf`] stops incrementing the counter slot once it reaches this
/// value.
pub const CDF_MAX_COUNT: u16 = 32;
46
47// =============================================================================
48// Partition CDFs
49// =============================================================================
50
/// Number of partition contexts.
pub const PARTITION_CONTEXTS: usize = 4;

/// Number of partition types.
pub const PARTITION_TYPES: usize = 10;

/// Default CDF for partition (context 0, small blocks).
///
/// Layout: ten cumulative values (one per partition type, ending at 32768)
/// followed by the adaptation counter slot.
pub const DEFAULT_PARTITION_CDF_0: [u16; 11] = [
    15588, 17570, 19323, 21084, 22472, 24311, 25744, 27999, 29223, 32768, 0,
];

/// Default CDF for partition (context 1, medium blocks).
///
/// Same layout as [`DEFAULT_PARTITION_CDF_0`].
pub const DEFAULT_PARTITION_CDF_1: [u16; 11] = [
    12064, 14616, 17239, 19824, 21631, 24068, 25919, 28400, 29760, 32768, 0,
];

/// Default CDF for partition (context 2, large blocks).
///
/// Same layout as [`DEFAULT_PARTITION_CDF_0`].
pub const DEFAULT_PARTITION_CDF_2: [u16; 11] = [
    9216, 12096, 15424, 18432, 20672, 23424, 25664, 28544, 30080, 32768, 0,
];

/// Default CDF for partition (context 3, very large blocks).
///
/// Same layout as [`DEFAULT_PARTITION_CDF_0`].
pub const DEFAULT_PARTITION_CDF_3: [u16; 11] = [
    6144, 9472, 13312, 16896, 19584, 22912, 25600, 28800, 30464, 32768, 0,
];

/// All default partition CDFs.
///
/// Indexed by partition context; entry `i` is `DEFAULT_PARTITION_CDF_i`.
pub const DEFAULT_PARTITION_CDFS: [[u16; 11]; PARTITION_CONTEXTS] = [
    DEFAULT_PARTITION_CDF_0,
    DEFAULT_PARTITION_CDF_1,
    DEFAULT_PARTITION_CDF_2,
    DEFAULT_PARTITION_CDF_3,
];
84
85// =============================================================================
86// Intra Mode CDFs
87// =============================================================================
88
/// Number of intra modes.
pub const INTRA_MODES: usize = 13;

/// Number of intra mode contexts for Y.
pub const INTRA_Y_MODE_CONTEXTS: usize = 4;

/// Default CDF for Y intra mode (context 0).
///
/// Layout: thirteen cumulative values (one per intra mode, ending at 32768)
/// followed by the adaptation counter slot.
pub const DEFAULT_Y_MODE_CDF_0: [u16; 14] = [
    15588, 17570, 18800, 20000, 21500, 23000, 24500, 26000, 27500, 29000, 30500, 31500, 32768, 0,
];

/// Default CDF for Y intra mode (context 1).
///
/// Same layout as [`DEFAULT_Y_MODE_CDF_0`].
pub const DEFAULT_Y_MODE_CDF_1: [u16; 14] = [
    12064, 14616, 16500, 18500, 20500, 22500, 24500, 26500, 28000, 29500, 30750, 31750, 32768, 0,
];

/// Default CDF for Y intra mode (context 2).
///
/// Same layout as [`DEFAULT_Y_MODE_CDF_0`].
pub const DEFAULT_Y_MODE_CDF_2: [u16; 14] = [
    9216, 12096, 14500, 17000, 19500, 22000, 24500, 26500, 28500, 30000, 31000, 32000, 32768, 0,
];

/// Default CDF for Y intra mode (context 3).
///
/// Same layout as [`DEFAULT_Y_MODE_CDF_0`].
pub const DEFAULT_Y_MODE_CDF_3: [u16; 14] = [
    6144, 9472, 12500, 15500, 18500, 21500, 24500, 27000, 29000, 30500, 31250, 32000, 32768, 0,
];

/// All default Y mode CDFs.
///
/// Indexed by Y mode context; entry `i` is `DEFAULT_Y_MODE_CDF_i`.
pub const DEFAULT_Y_MODE_CDFS: [[u16; 14]; INTRA_Y_MODE_CONTEXTS] = [
    DEFAULT_Y_MODE_CDF_0,
    DEFAULT_Y_MODE_CDF_1,
    DEFAULT_Y_MODE_CDF_2,
    DEFAULT_Y_MODE_CDF_3,
];

/// Number of UV intra mode contexts.
pub const INTRA_UV_MODE_CONTEXTS: usize = 13;

/// Default CDF for UV intra mode (for CFL disabled).
///
/// Thirteen cumulative values followed by the adaptation counter slot.
pub const DEFAULT_UV_MODE_CDF_NO_CFL: [u16; 14] = [
    22528, 24320, 25344, 26368, 27136, 28160, 28928, 29696, 30464, 31104, 31616, 32128, 32768, 0,
];

/// Default CDF for UV intra mode (for CFL enabled).
///
/// Fourteen cumulative values (one more symbol than the CFL-disabled table)
/// followed by the adaptation counter slot.
pub const DEFAULT_UV_MODE_CDF_CFL: [u16; 15] = [
    18432, 20480, 22016, 23296, 24576, 25856, 27136, 28160, 29184, 30080, 30848, 31488, 32000,
    32768, 0,
];
136
137// =============================================================================
138// Transform Size CDFs
139// =============================================================================
140
/// Number of TX size contexts.
pub const TX_SIZE_CONTEXTS: usize = 3;

/// Number of max TX size categories.
pub const MAX_TX_CATS: usize = 4;

/// Default CDF for TX size (max 8x8).
///
/// Two cumulative values followed by the adaptation counter slot.
pub const DEFAULT_TX_SIZE_CDF_8X8: [u16; 3] = [16384, 32768, 0];

/// Default CDF for TX size (max 16x16).
///
/// Three cumulative values followed by the adaptation counter slot.
pub const DEFAULT_TX_SIZE_CDF_16X16: [u16; 4] = [10923, 21845, 32768, 0];

/// Default CDF for TX size (max 32x32).
///
/// Four cumulative values followed by the adaptation counter slot.
pub const DEFAULT_TX_SIZE_CDF_32X32: [u16; 5] = [8192, 16384, 24576, 32768, 0];

/// Default CDF for TX size (max 64x64).
///
/// Five cumulative values followed by the adaptation counter slot.
pub const DEFAULT_TX_SIZE_CDF_64X64: [u16; 6] = [6554, 13107, 19661, 26214, 32768, 0];
158
159// =============================================================================
160// Transform Type CDFs
161// =============================================================================
162
/// Number of TX type contexts per set.
pub const TX_TYPE_CONTEXTS: usize = 7;

/// Number of transform types for intra.
pub const INTRA_TX_TYPES: usize = 7;

/// Number of transform types for inter.
pub const INTER_TX_TYPES: usize = 16;

/// Default CDF for intra TX type (TX_4X4).
///
/// One row per TX type context.
///
/// NOTE(review): unlike the other tables in this file, each row holds eight
/// cumulative values ending at 32768 and has no trailing counter slot, while
/// `INTRA_TX_TYPES` is 7. Confirm the intended layout before handing a row to
/// `update_cdf` or `is_valid_cdf`, which both treat the last element as the
/// counter.
pub const DEFAULT_INTRA_TX_TYPE_4X4: [[u16; 8]; TX_TYPE_CONTEXTS] = [
    [5461, 10923, 16384, 21845, 24576, 27307, 30037, 32768],
    [4681, 9362, 14043, 18725, 22118, 25512, 28905, 32768],
    [4096, 8192, 12288, 16384, 20480, 24576, 28672, 32768],
    [3641, 7282, 10923, 14564, 18893, 23222, 27551, 32768],
    [3277, 6554, 9830, 13107, 17476, 21845, 26214, 32768],
    [2979, 5958, 8937, 11916, 16213, 20511, 25228, 32768],
    [2731, 5461, 8192, 10923, 15019, 19114, 24064, 32768],
];

/// Default CDF for intra TX type (TX_8X8).
///
/// One row per TX type context.
///
/// NOTE(review): same layout caveat as [`DEFAULT_INTRA_TX_TYPE_4X4`] - eight
/// cumulative values per row and no counter slot.
pub const DEFAULT_INTRA_TX_TYPE_8X8: [[u16; 8]; TX_TYPE_CONTEXTS] = [
    [6144, 12288, 18432, 24576, 26624, 28672, 30720, 32768],
    [5461, 10923, 16384, 21845, 24576, 27307, 30037, 32768],
    [4915, 9830, 14745, 19660, 22938, 26214, 29491, 32768],
    [4455, 8909, 13364, 17818, 21399, 24980, 28561, 32768],
    [4096, 8192, 12288, 16384, 20070, 23756, 27852, 32768],
    [3780, 7559, 11339, 15119, 18897, 22675, 27200, 32768],
    [3495, 6991, 10486, 13981, 17827, 21673, 26600, 32768],
];

/// Default CDF for inter TX type.
///
/// Every context starts from the same evenly spaced 16-symbol distribution
/// (steps of 2048) followed by the adaptation counter slot, so the table is
/// written as a single repeated row.
pub const DEFAULT_INTER_TX_TYPE: [[u16; 17]; TX_TYPE_CONTEXTS] = [[
    2048, 4096, 6144, 8192, 10240, 12288, 14336, 16384, 18432, 20480, 22528, 24576, 26624, 28672,
    30720, 32768, 0,
]; TX_TYPE_CONTEXTS];
225
226// =============================================================================
227// Coefficient CDFs
228// =============================================================================
229
/// Number of EOB multi contexts per plane type.
///
/// Per AV1 Annex F §9.5, EOB symbol coding is parameterised by the transform
/// area class. The seven classes correspond to areas 16, 32, 64, 128, 256,
/// 512, and 1024+ — matching the seven distinct `eob_multi*_cdf` tables in
/// the reference decoder (`EobMultiSize16Cdf` through `EobMultiSize1024Cdf`).
pub const EOB_MULTI_CONTEXTS: usize = 7;

/// Number of transform sizes in AV1 (matches `TxSize` enum, indices 0..=18).
pub const TX_SIZE_COUNT: usize = 19;

/// Number of plane types (luma vs. one of two chroma planes).
pub const EOB_PLANE_COUNT: usize = 3;

/// Number of `(tx_size, plane)` contexts addressed by the caller.
///
/// The decoder/encoder packs the EOB context as
/// `ctx = tx_size_idx * EOB_PLANE_COUNT + plane`, with `tx_size_idx` in
/// `0..TX_SIZE_COUNT` and `plane` in `0..EOB_PLANE_COUNT`. The product is
/// the number of distinct CDF slots maintained per frame.
pub const EOB_MULTI_TOTAL_CONTEXTS: usize = TX_SIZE_COUNT * EOB_PLANE_COUNT;

/// Default CDF for EOB multi (2 symbols).
///
/// Two cumulative values followed by the adaptation counter slot.
pub const DEFAULT_EOB_MULTI_2: [u16; 3] = [16384, 32768, 0];

/// Default CDF for EOB multi (4 symbols).
///
/// Four evenly spaced cumulative values followed by the counter slot.
pub const DEFAULT_EOB_MULTI_4: [u16; 5] = [8192, 16384, 24576, 32768, 0];

/// Default CDF for EOB multi (8 symbols).
///
/// Eight evenly spaced cumulative values followed by the counter slot.
pub const DEFAULT_EOB_MULTI_8: [u16; 9] = [4096, 8192, 12288, 16384, 20480, 24576, 28672, 32768, 0];

/// Default CDF for EOB multi (16 symbols).
///
/// Sixteen evenly spaced cumulative values followed by the counter slot.
pub const DEFAULT_EOB_MULTI_16: [u16; 17] = [
    2048, 4096, 6144, 8192, 10240, 12288, 14336, 16384, 18432, 20480, 22528, 24576, 26624, 28672,
    30720, 32768, 0,
];
266
267/// Map a `TxSize` integer index to the appropriate default EOB-multi CDF.
268///
269/// The class is chosen by transform area, matching the maximum value of
270/// [`crate::av1::coefficients::EobPt::from_eob`] for that area:
271///
272/// | Area   | TxSizes                       | max `EobPt` | CDF picked |
273/// |--------|-------------------------------|-------------|------------|
274/// | 16     | 4×4                           | 5           | `_8` (9)   |
275/// | 32     | 4×8, 8×4                      | 6           | `_8` (9)   |
276/// | 64     | 8×8, 4×16, 16×4               | 7           | `_8` (9)   |
277/// | 128    | 8×16, 16×8                    | 8           | `_16` (17) |
278/// | 256    | 16×16, 8×32, 32×8             | 9           | `_16` (17) |
279/// | 512    | 16×32, 32×16                  | 10          | `_16` (17) |
280/// | ≥ 1024 | 32×32, 32×64, 64×32, 64×64, … | 11          | `_16` (17) |
281///
282/// `DEFAULT_EOB_MULTI_2` and `DEFAULT_EOB_MULTI_4` are retained for
283/// reference (they appear in the public spec text) but are not selected
284/// because no transform area yields a max EOB point ≤ 3 in the present
285/// decoder.
286#[must_use]
287fn default_eob_multi_for_tx_size_idx(tx_size_idx: usize) -> Vec<u16> {
288    // Areas of TxSize variants in declaration order. Out-of-range
289    // indices fall through to the safest (largest) bucket.
290    const TX_SIZE_AREAS: [u32; TX_SIZE_COUNT] = [
291        16,   // Tx4x4
292        64,   // Tx8x8
293        256,  // Tx16x16
294        1024, // Tx32x32
295        4096, // Tx64x64
296        32,   // Tx4x8
297        32,   // Tx8x4
298        128,  // Tx8x16
299        128,  // Tx16x8
300        512,  // Tx16x32
301        512,  // Tx32x16
302        2048, // Tx32x64
303        2048, // Tx64x32
304        64,   // Tx4x16
305        64,   // Tx16x4
306        256,  // Tx8x32
307        256,  // Tx32x8
308        1024, // Tx16x64
309        1024, // Tx64x16
310    ];
311
312    let area = TX_SIZE_AREAS.get(tx_size_idx).copied().unwrap_or(4096);
313
314    if area <= 64 {
315        DEFAULT_EOB_MULTI_8.to_vec()
316    } else {
317        DEFAULT_EOB_MULTI_16.to_vec()
318    }
319}
320
/// Number of coefficient base contexts.
///
/// `CdfContext::new` allocates this many coefficient base CDFs.
pub const COEFF_BASE_CTX_COUNT: usize = 42;

/// Default CDF for coefficient base (4 levels: 0, 1, 2, >2).
///
/// Four evenly spaced cumulative values followed by the counter slot.
pub const DEFAULT_COEFF_BASE_CDF: [u16; 5] = [8192, 16384, 24576, 32768, 0];

/// Default CDF for coefficient base EOB.
///
/// Three cumulative values followed by the counter slot.
///
/// NOTE(review): `CdfContext::get_coeff_base_eob_cdf` currently returns the
/// coefficient base tables rather than a table built from this constant.
pub const DEFAULT_COEFF_BASE_EOB_CDF: [u16; 4] = [10923, 21845, 32768, 0];

/// Number of DC sign contexts.
pub const DC_SIGN_CTX_COUNT: usize = 3;

/// Default CDF for DC sign.
///
/// Two equiprobable symbols followed by the counter slot.
pub const DEFAULT_DC_SIGN_CDF: [u16; 3] = [16384, 32768, 0];

/// Number of coefficient base range contexts.
///
/// `CdfContext::new` allocates this many coefficient base range CDFs.
pub const COEFF_BR_CTX_COUNT: usize = 21;

/// Default CDF for coefficient base range.
///
/// Three cumulative values followed by the counter slot.
pub const DEFAULT_COEFF_BR_CDF: [u16; 4] = [10923, 21845, 32768, 0];
341
342// =============================================================================
343// Motion Vector CDFs
344// =============================================================================
345
/// Number of MV joint types.
pub const MV_JOINTS: usize = 4;

/// Default CDF for MV joint.
///
/// Four cumulative values followed by the adaptation counter slot.
pub const DEFAULT_MV_JOINT_CDF: [u16; 5] = [4096, 11264, 19712, 32768, 0];

/// Number of MV classes.
pub const MV_CLASSES: usize = 11;

/// Default CDF for MV class.
///
/// Eleven cumulative values followed by the adaptation counter slot.
pub const DEFAULT_MV_CLASS_CDF: [u16; 12] = [
    28672, 30976, 31744, 32128, 32320, 32448, 32544, 32608, 32672, 32720, 32768, 0,
];

/// Default CDF for MV class 0 bit.
pub const DEFAULT_MV_CLASS0_BIT_CDF: [u16; 3] = [16384, 32768, 0];

/// Number of MV class 0 fractional values.
pub const MV_CLASS0_FP: usize = 4;

/// Default CDF for MV class 0 fractional.
///
/// Both rows start from the same uniform four-symbol distribution.
pub const DEFAULT_MV_CLASS0_FP_CDF: [[u16; 5]; 2] = [[8192, 16384, 24576, 32768, 0]; 2];

/// Number of MV fractional values.
pub const MV_FP: usize = 4;

/// Default CDF for MV fractional.
pub const DEFAULT_MV_FP_CDF: [u16; 5] = [8192, 16384, 24576, 32768, 0];

/// Default CDF for MV class 0 high precision.
pub const DEFAULT_MV_CLASS0_HP_CDF: [u16; 3] = [16384, 32768, 0];

/// Default CDF for MV high precision.
pub const DEFAULT_MV_HP_CDF: [u16; 3] = [16384, 32768, 0];

/// Default CDF for MV sign.
pub const DEFAULT_MV_SIGN_CDF: [u16; 3] = [16384, 32768, 0];

/// Number of MV bits for class > 0.
pub const MV_OFFSET_BITS: usize = 10;

/// Default CDF for MV bits.
///
/// One equiprobable binary CDF per offset bit.
pub const DEFAULT_MV_BITS_CDF: [[u16; 3]; MV_OFFSET_BITS] = [[16384, 32768, 0]; MV_OFFSET_BITS];
403
404// =============================================================================
405// Skip CDFs
406// =============================================================================
407
/// Number of skip contexts.
pub const SKIP_CONTEXTS: usize = 3;

/// Default CDF for skip.
///
/// One binary CDF per skip context: two cumulative values followed by the
/// adaptation counter slot. The first cumulative value falls from 24576 to
/// 8192 as the context index rises.
pub const DEFAULT_SKIP_CDF: [[u16; 3]; SKIP_CONTEXTS] =
    [[24576, 32768, 0], [16384, 32768, 0], [8192, 32768, 0]];
414
415// =============================================================================
416// Segment CDFs
417// =============================================================================
418
/// Maximum number of segments.
pub const MAX_SEGMENTS: usize = 8;

/// Default CDF for segment ID (tree).
///
/// Eight evenly spaced cumulative values followed by the counter slot.
pub const DEFAULT_SEGMENT_TREE_CDF: [u16; 9] = [
    4096, 8192, 12288, 16384, 20480, 24576, 28672, 32768, 0,
];

/// Default CDF for segment ID prediction.
///
/// Three identical equiprobable binary CDFs.
pub const DEFAULT_SEGMENT_PRED_CDF: [[u16; 3]; 3] = [[16384, 32768, 0]; 3];
429
430// =============================================================================
431// Reference Frame CDFs
432// =============================================================================
433
/// Number of reference frame contexts.
pub const REF_CONTEXTS: usize = 3;

/// Number of reference frame types for single ref.
pub const SINGLE_REF_TYPES: usize = 7;

/// Default CDF for single reference frame.
///
/// Indexed as `[context][ref_type]`. Every entry starts as the same
/// equiprobable binary CDF with a zeroed counter slot.
pub const DEFAULT_SINGLE_REF_CDF: [[[u16; 3]; SINGLE_REF_TYPES]; REF_CONTEXTS] =
    [[[16384, 32768, 0]; SINGLE_REF_TYPES]; REF_CONTEXTS];
470
471// =============================================================================
472// Inter Mode CDFs
473// =============================================================================
474
/// Number of inter mode contexts.
pub const INTER_MODE_CONTEXTS: usize = 8;

/// Number of inter modes.
pub const INTER_MODES: usize = 4;

/// Default CDF for inter mode.
///
/// One row per inter mode context. Each row holds four cumulative values
/// (one per inter mode, ending at 32768) followed by the adaptation counter
/// slot.
pub const DEFAULT_INTER_MODE_CDF: [[u16; 5]; INTER_MODE_CONTEXTS] = [
    [2048, 10240, 17664, 32768, 0],
    [4096, 12288, 20480, 32768, 0],
    [6144, 14336, 22528, 32768, 0],
    [8192, 16384, 24576, 32768, 0],
    [10240, 18432, 26624, 32768, 0],
    [12288, 20480, 28672, 32768, 0],
    [14336, 22528, 29696, 32768, 0],
    [16384, 24576, 30720, 32768, 0],
];
492
493// =============================================================================
494// Compound Mode CDFs
495// =============================================================================
496
/// Number of compound mode contexts.
pub const COMPOUND_MODE_CONTEXTS: usize = 8;

/// Number of compound modes.
pub const COMPOUND_MODES: usize = 8;

/// Default CDF for compound mode.
///
/// Every context starts from the same uniform eight-symbol distribution
/// (steps of 4096) followed by the adaptation counter slot.
pub const DEFAULT_COMPOUND_MODE_CDF: [[u16; 9]; COMPOUND_MODE_CONTEXTS] =
    [[4096, 8192, 12288, 16384, 20480, 24576, 28672, 32768, 0]; COMPOUND_MODE_CONTEXTS];
514
515// =============================================================================
516// Filter CDFs
517// =============================================================================
518
/// Number of interpolation filter types.
pub const INTERP_FILTERS: usize = 4;

/// Number of interpolation filter contexts.
pub const INTERP_FILTER_CONTEXTS: usize = 16;

/// Default CDF for interpolation filter.
///
/// The sixteen contexts fall into four consecutive groups of four, and each
/// group shares one starting distribution (four cumulative values plus the
/// adaptation counter slot).
pub const DEFAULT_INTERP_FILTER_CDF: [[u16; 5]; INTERP_FILTER_CONTEXTS] = {
    // Shared starting rows for contexts 0-3, 4-7, 8-11 and 12-15.
    const GROUP_0: [u16; 5] = [6144, 12288, 18432, 32768, 0];
    const GROUP_1: [u16; 5] = [8192, 16384, 24576, 32768, 0];
    const GROUP_2: [u16; 5] = [10240, 18432, 26624, 32768, 0];
    const GROUP_3: [u16; 5] = [12288, 20480, 28672, 32768, 0];

    [
        GROUP_0, GROUP_0, GROUP_0, GROUP_0, //
        GROUP_1, GROUP_1, GROUP_1, GROUP_1, //
        GROUP_2, GROUP_2, GROUP_2, GROUP_2, //
        GROUP_3, GROUP_3, GROUP_3, GROUP_3,
    ]
};
544
545// =============================================================================
546// CDF Helper Functions
547// =============================================================================
548
549/// Create a uniform CDF for n symbols.
550#[must_use]
551#[allow(clippy::cast_possible_truncation)]
552pub fn create_uniform_cdf(n: usize) -> Vec<u16> {
553    let mut cdf = Vec::with_capacity(n + 1);
554    for i in 1..=n {
555        cdf.push(((i * CDF_PROB_TOP as usize) / n) as u16);
556    }
557    cdf.push(CDF_INIT_COUNT); // Symbol count
558    cdf
559}
560
561/// Update CDF after observing a symbol.
562#[allow(clippy::cast_possible_truncation)]
563pub fn update_cdf(cdf: &mut [u16], symbol: usize) {
564    let n = cdf.len() - 1; // Last element is count
565    if n == 0 {
566        return;
567    }
568
569    // Get current count and compute rate
570    let count = u32::from(cdf[n]);
571    let rate = 3 + (count >> 4);
572    let rate = rate.min(32);
573
574    // Update CDF values
575    for i in 0..n {
576        if i < symbol {
577            // Decrease probability
578            let diff = cdf[i] >> rate;
579            cdf[i] = cdf[i].saturating_sub(diff);
580        } else {
581            // Increase probability
582            let diff = CDF_PROB_TOP.saturating_sub(cdf[i]) >> rate;
583            cdf[i] = cdf[i].saturating_add(diff);
584        }
585    }
586
587    // Increment count
588    if count < u32::from(CDF_MAX_COUNT) {
589        cdf[n] += 1;
590    }
591}
592
593/// Reset CDF to uniform distribution.
594#[allow(clippy::cast_possible_truncation)]
595pub fn reset_cdf_uniform(cdf: &mut [u16]) {
596    let n = cdf.len() - 1;
597    if n == 0 {
598        return;
599    }
600
601    for i in 0..n {
602        cdf[i] = (((i + 1) * CDF_PROB_TOP as usize) / n) as u16;
603    }
604    cdf[n] = CDF_INIT_COUNT;
605}
606
/// Copy CDF values from `src` into `dst`.
///
/// Only the overlapping prefix is written: when the slices differ in length,
/// the copy stops at the shorter one and any remaining `dst` elements keep
/// their previous values.
pub fn copy_cdf(dst: &mut [u16], src: &[u16]) {
    for (slot, &value) in dst.iter_mut().zip(src) {
        *slot = value;
    }
}
612
613/// Check if a CDF is valid (monotonically increasing, ends at CDF_PROB_TOP).
614#[must_use]
615pub fn is_valid_cdf(cdf: &[u16]) -> bool {
616    if cdf.len() < 2 {
617        return false;
618    }
619
620    let n = cdf.len() - 1;
621
622    // Check monotonicity
623    for i in 1..n {
624        if cdf[i] < cdf[i - 1] {
625            return false;
626        }
627    }
628
629    // Last probability should be CDF_PROB_TOP
630    cdf[n - 1] == CDF_PROB_TOP
631}
632
633// =============================================================================
634// CDF Context Management
635// =============================================================================
636
/// Container for all CDF tables used in decoding.
///
/// Every table starts out as a copy of the corresponding `DEFAULT_*`
/// constant (see `CdfContext::new`). All fields are public, so callers can
/// read and modify the tables directly as well as through the accessor
/// methods.
#[derive(Clone, Debug)]
pub struct CdfContext {
    /// Partition CDFs.
    ///
    /// One table per partition context.
    pub partition: [[u16; 11]; PARTITION_CONTEXTS],
    /// Y intra mode CDFs.
    ///
    /// One table per Y mode context.
    pub y_mode: [[u16; 14]; INTRA_Y_MODE_CONTEXTS],
    /// Skip CDFs.
    ///
    /// One binary table per skip context.
    pub skip: [[u16; 3]; SKIP_CONTEXTS],
    /// MV joint CDF.
    pub mv_joint: [u16; 5],
    /// MV sign CDFs (for each component).
    pub mv_sign: [[u16; 3]; 2],
    /// MV class CDFs.
    ///
    /// Two tables. NOTE(review): presumably one per MV component, like
    /// `mv_sign` — confirm against the caller.
    pub mv_class: [[u16; 12]; 2],
    /// DC sign CDFs.
    ///
    /// One binary table per DC sign context.
    pub dc_sign: [[u16; 3]; DC_SIGN_CTX_COUNT],
    /// Coefficient base CDFs.
    ///
    /// Heap-allocated; `CdfContext::new` creates `COEFF_BASE_CTX_COUNT`
    /// entries.
    pub coeff_base: Vec<[u16; 5]>,
    /// Coefficient base range CDFs.
    ///
    /// Heap-allocated; `CdfContext::new` creates `COEFF_BR_CTX_COUNT`
    /// entries.
    pub coeff_br: Vec<[u16; 4]>,
    /// EOB multi-symbol CDFs.
    ///
    /// Indexed by `ctx = tx_size_idx * EOB_PLANE_COUNT + plane` (see
    /// [`EOB_MULTI_TOTAL_CONTEXTS`]). Each inner `Vec<u16>` is a CDF whose
    /// length is selected per transform area — `_8` (9 entries) for small
    /// blocks (area ≤ 64) and `_16` (17 entries) for larger blocks.
    pub eob_multi: Vec<Vec<u16>>,
}
666
667impl CdfContext {
668    /// Create a new CDF context with default values.
669    #[must_use]
670    pub fn new() -> Self {
671        let eob_multi = (0..EOB_MULTI_TOTAL_CONTEXTS)
672            .map(|ctx| default_eob_multi_for_tx_size_idx(ctx / EOB_PLANE_COUNT))
673            .collect();
674
675        Self {
676            partition: DEFAULT_PARTITION_CDFS,
677            y_mode: DEFAULT_Y_MODE_CDFS,
678            skip: DEFAULT_SKIP_CDF,
679            mv_joint: DEFAULT_MV_JOINT_CDF,
680            mv_sign: [DEFAULT_MV_SIGN_CDF, DEFAULT_MV_SIGN_CDF],
681            mv_class: [DEFAULT_MV_CLASS_CDF, DEFAULT_MV_CLASS_CDF],
682            dc_sign: [DEFAULT_DC_SIGN_CDF; DC_SIGN_CTX_COUNT],
683            coeff_base: vec![DEFAULT_COEFF_BASE_CDF; COEFF_BASE_CTX_COUNT],
684            coeff_br: vec![DEFAULT_COEFF_BR_CDF; COEFF_BR_CTX_COUNT],
685            eob_multi,
686        }
687    }
688
689    /// Reset all CDFs to default values.
690    pub fn reset(&mut self) {
691        self.partition = DEFAULT_PARTITION_CDFS;
692        self.y_mode = DEFAULT_Y_MODE_CDFS;
693        self.skip = DEFAULT_SKIP_CDF;
694        self.mv_joint = DEFAULT_MV_JOINT_CDF;
695        self.mv_sign = [DEFAULT_MV_SIGN_CDF, DEFAULT_MV_SIGN_CDF];
696        self.mv_class = [DEFAULT_MV_CLASS_CDF, DEFAULT_MV_CLASS_CDF];
697        self.dc_sign = [DEFAULT_DC_SIGN_CDF; DC_SIGN_CTX_COUNT];
698
699        for cdf in &mut self.coeff_base {
700            *cdf = DEFAULT_COEFF_BASE_CDF;
701        }
702
703        for cdf in &mut self.coeff_br {
704            *cdf = DEFAULT_COEFF_BR_CDF;
705        }
706
707        // Rebuild EOB multi CDFs so each ctx is reset to the proper default
708        // for its `(tx_size, plane)` pair. We resize first to guarantee
709        // length invariants even if a caller previously shrank the Vec.
710        self.eob_multi
711            .resize_with(EOB_MULTI_TOTAL_CONTEXTS, Vec::new);
712        for (ctx, cdf) in self.eob_multi.iter_mut().enumerate() {
713            *cdf = default_eob_multi_for_tx_size_idx(ctx / EOB_PLANE_COUNT);
714        }
715    }
716
717    /// Get partition CDF for a context.
718    #[must_use]
719    pub fn get_partition_cdf(&self, ctx: usize) -> &[u16; 11] {
720        &self.partition[ctx.min(PARTITION_CONTEXTS - 1)]
721    }
722
723    /// Get mutable partition CDF for a context.
724    pub fn get_partition_cdf_mut(&mut self, ctx: usize) -> &mut [u16; 11] {
725        &mut self.partition[ctx.min(PARTITION_CONTEXTS - 1)]
726    }
727
728    /// Get Y mode CDF for a context.
729    #[must_use]
730    pub fn get_y_mode_cdf(&self, ctx: usize) -> &[u16; 14] {
731        &self.y_mode[ctx.min(INTRA_Y_MODE_CONTEXTS - 1)]
732    }
733
734    /// Get skip CDF for a context.
735    #[must_use]
736    pub fn get_skip_cdf(&self, ctx: usize) -> &[u16; 3] {
737        &self.skip[ctx.min(SKIP_CONTEXTS - 1)]
738    }
739
740    /// Get EOB multi CDF for a context.
741    ///
742    /// The context is `tx_size_idx * EOB_PLANE_COUNT + plane`. Indices that
743    /// fall outside [`EOB_MULTI_TOTAL_CONTEXTS`] are clamped to the final
744    /// slot, which holds the largest-area default CDF (see
745    /// [`default_eob_multi_for_tx_size_idx`]).
746    #[must_use]
747    pub fn get_eob_multi_cdf(&self, ctx: usize) -> &[u16] {
748        let idx = ctx.min(self.eob_multi.len().saturating_sub(1));
749        // The Vec is always populated in `new()`/`reset()`, but defend
750        // against degenerate states without unwrap.
751        match self.eob_multi.get(idx) {
752            Some(cdf) => cdf.as_slice(),
753            None => &DEFAULT_EOB_MULTI_16,
754        }
755    }
756
757    /// Get coefficient base CDF for a context.
758    #[must_use]
759    pub fn get_coeff_base_cdf(&self, ctx: usize) -> &[u16] {
760        if ctx < self.coeff_base.len() {
761            &self.coeff_base[ctx]
762        } else {
763            &DEFAULT_COEFF_BASE_CDF
764        }
765    }
766
767    /// Get coefficient base EOB CDF for a context.
768    #[must_use]
769    pub fn get_coeff_base_eob_cdf(&self, ctx: usize) -> &[u16] {
770        // For EOB position, use the same as coeff_base
771        self.get_coeff_base_cdf(ctx)
772    }
773
774    /// Get coefficient BR (base range) CDF for a context.
775    #[must_use]
776    pub fn get_coeff_br_cdf(&self, ctx: usize) -> &[u16] {
777        if ctx < self.coeff_br.len() {
778            &self.coeff_br[ctx]
779        } else {
780            &DEFAULT_COEFF_BR_CDF
781        }
782    }
783
784    /// Get DC sign CDF for a context.
785    #[must_use]
786    pub fn get_dc_sign_cdf(&self, ctx: usize) -> &[u16] {
787        &self.dc_sign[ctx.min(DC_SIGN_CTX_COUNT - 1)]
788    }
789
790    // Mutable versions for updating CDFs during decoding
791
792    /// Get mutable EOB multi CDF for a context.
793    ///
794    /// Mirrors [`Self::get_eob_multi_cdf`] but yields a mutable slice so
795    /// the arithmetic coder can adapt the CDF after each symbol. Returns a
796    /// slice into the largest-area slot when `ctx` is out of range.
797    pub fn get_eob_multi_cdf_mut(&mut self, ctx: usize) -> &mut [u16] {
798        if self.eob_multi.is_empty() {
799            // Guarantee the invariant: rebuild from defaults rather than
800            // returning an empty slice that would crash arithmetic coding.
801            self.eob_multi = (0..EOB_MULTI_TOTAL_CONTEXTS)
802                .map(|c| default_eob_multi_for_tx_size_idx(c / EOB_PLANE_COUNT))
803                .collect();
804        }
805        let len = self.eob_multi.len();
806        let idx = ctx.min(len - 1);
807        // `idx < len` after the clamp above, so the indexed access is safe.
808        // We use direct indexing because `get_mut` returns `Option<&mut Vec>`
809        // which complicates the slice return without unwrap; the bound is
810        // already enforced.
811        self.eob_multi[idx].as_mut_slice()
812    }
813
814    /// Get mutable coefficient base CDF for a context.
815    pub fn get_coeff_base_cdf_mut(&mut self, ctx: usize) -> &mut [u16] {
816        if ctx < self.coeff_base.len() {
817            &mut self.coeff_base[ctx]
818        } else if !self.coeff_base.is_empty() {
819            &mut self.coeff_base[0]
820        } else {
821            &mut []
822        }
823    }
824
825    /// Get mutable coefficient base EOB CDF for a context.
826    pub fn get_coeff_base_eob_cdf_mut(&mut self, ctx: usize) -> &mut [u16] {
827        self.get_coeff_base_cdf_mut(ctx)
828    }
829
830    /// Get mutable coefficient BR (base range) CDF for a context.
831    pub fn get_coeff_br_cdf_mut(&mut self, ctx: usize) -> &mut [u16] {
832        if ctx < self.coeff_br.len() {
833            &mut self.coeff_br[ctx]
834        } else if !self.coeff_br.is_empty() {
835            &mut self.coeff_br[0]
836        } else {
837            &mut []
838        }
839    }
840
841    /// Get mutable DC sign CDF for a context.
842    pub fn get_dc_sign_cdf_mut(&mut self, ctx: usize) -> &mut [u16] {
843        let idx = ctx.min(DC_SIGN_CTX_COUNT - 1);
844        &mut self.dc_sign[idx]
845    }
846}
847
848impl Default for CdfContext {
849    fn default() -> Self {
850        Self::new()
851    }
852}
853
854// =============================================================================
855// Tests
856// =============================================================================
857
858#[cfg(test)]
859mod tests {
860    use super::*;
861
862    #[test]
863    fn test_create_uniform_cdf() {
864        let cdf = create_uniform_cdf(4);
865        assert_eq!(cdf.len(), 5);
866        assert_eq!(cdf[0], 8192);
867        assert_eq!(cdf[1], 16384);
868        assert_eq!(cdf[2], 24576);
869        assert_eq!(cdf[3], 32768);
870        assert_eq!(cdf[4], 0); // Count
871    }
872
873    #[test]
874    fn test_update_cdf() {
875        let mut cdf = create_uniform_cdf(4);
876
877        // Update with symbol 0 should increase its probability
878        let orig_0 = cdf[0];
879        update_cdf(&mut cdf, 0);
880        assert!(cdf[0] >= orig_0);
881    }
882
883    #[test]
884    fn test_reset_cdf_uniform() {
885        let mut cdf = vec![0u16; 5];
886        reset_cdf_uniform(&mut cdf);
887
888        assert_eq!(cdf[0], 8192);
889        assert_eq!(cdf[3], 32768);
890        assert_eq!(cdf[4], 0);
891    }
892
893    #[test]
894    fn test_copy_cdf() {
895        let src = create_uniform_cdf(4);
896        let mut dst = vec![0u16; 5];
897
898        copy_cdf(&mut dst, &src);
899        assert_eq!(dst, src);
900    }
901
902    #[test]
903    fn test_is_valid_cdf() {
904        let valid_cdf = create_uniform_cdf(4);
905        assert!(is_valid_cdf(&valid_cdf));
906
907        let invalid_cdf = vec![100u16, 50, 200, 32768, 0]; // Not monotonic
908        assert!(!is_valid_cdf(&invalid_cdf));
909    }
910
911    #[test]
912    fn test_cdf_context_new() {
913        let ctx = CdfContext::new();
914        assert_eq!(ctx.partition.len(), PARTITION_CONTEXTS);
915        assert_eq!(ctx.y_mode.len(), INTRA_Y_MODE_CONTEXTS);
916    }
917
918    #[test]
919    fn test_cdf_context_reset() {
920        let mut ctx = CdfContext::new();
921        ctx.partition[0][0] = 12345;
922
923        ctx.reset();
924        assert_eq!(ctx.partition[0], DEFAULT_PARTITION_CDFS[0]);
925    }
926
927    #[test]
928    fn test_get_partition_cdf() {
929        let ctx = CdfContext::new();
930
931        let cdf = ctx.get_partition_cdf(0);
932        assert_eq!(cdf, &DEFAULT_PARTITION_CDF_0);
933
934        // Out of bounds should clamp
935        let cdf_clamped = ctx.get_partition_cdf(100);
936        assert_eq!(cdf_clamped, &DEFAULT_PARTITION_CDF_3);
937    }
938
939    #[test]
940    fn test_default_cdfs_valid() {
941        // Check that default CDFs are valid
942        for cdf in &DEFAULT_PARTITION_CDFS {
943            assert!(is_valid_cdf(cdf));
944        }
945
946        for cdf in &DEFAULT_Y_MODE_CDFS {
947            assert!(is_valid_cdf(cdf));
948        }
949
950        assert!(is_valid_cdf(&DEFAULT_MV_JOINT_CDF));
951        assert!(is_valid_cdf(&DEFAULT_MV_CLASS_CDF));
952    }
953
954    #[test]
955    fn test_cdf_constants() {
956        assert_eq!(CDF_PROB_BITS, 15);
957        assert_eq!(CDF_PROB_TOP, 32768);
958        assert_eq!(CDF_MAX_COUNT, 32);
959    }
960
961    #[test]
962    fn test_partition_contexts() {
963        assert_eq!(PARTITION_CONTEXTS, 4);
964        assert_eq!(PARTITION_TYPES, 10);
965    }
966
967    #[test]
968    fn test_intra_mode_contexts() {
969        assert_eq!(INTRA_MODES, 13);
970        assert_eq!(INTRA_Y_MODE_CONTEXTS, 4);
971    }
972
973    #[test]
974    fn test_mv_constants() {
975        assert_eq!(MV_JOINTS, 4);
976        assert_eq!(MV_CLASSES, 11);
977        assert_eq!(MV_OFFSET_BITS, 10);
978    }
979
980    // -------------------------------------------------------------------
981    // EOB multi-symbol CDF context routing tests
982    // -------------------------------------------------------------------
983
    /// Tx-size indices match the `TxSize` enum order (0..=18).
    ///
    /// Only the sizes referenced by the tests below are named, which is why
    /// indices 11, 12 and 15..=18 have no constant here.
    // NOTE(review): 11 and 12 are presumably TX_32X64 / TX_64X32 in the AV1
    // transform-size order — confirm against the `TxSize` enum.
    // Square transforms.
    const TX_4X4: usize = 0;
    const TX_8X8: usize = 1;
    const TX_16X16: usize = 2;
    const TX_32X32: usize = 3;
    const TX_64X64: usize = 4;
    // 1:2 and 2:1 rectangular transforms.
    const TX_4X8: usize = 5;
    const TX_8X4: usize = 6;
    const TX_8X16: usize = 7;
    const TX_16X8: usize = 8;
    const TX_16X32: usize = 9;
    const TX_32X16: usize = 10;
    // 1:4 and 4:1 rectangular transforms (note the jump over 11 and 12).
    const TX_4X16: usize = 13;
    const TX_16X4: usize = 14;
998
999    fn eob_ctx_for(tx_size_idx: usize, plane: usize) -> usize {
1000        tx_size_idx * EOB_PLANE_COUNT + plane
1001    }
1002
1003    #[test]
1004    fn test_eob_multi_total_contexts() {
1005        // 19 transform sizes × 3 plane types = 57 contexts.
1006        assert_eq!(EOB_MULTI_TOTAL_CONTEXTS, 57);
1007        assert_eq!(TX_SIZE_COUNT, 19);
1008        assert_eq!(EOB_PLANE_COUNT, 3);
1009    }
1010
1011    #[test]
1012    fn test_eob_multi_vec_populated() {
1013        let ctx = CdfContext::new();
1014        assert_eq!(ctx.eob_multi.len(), EOB_MULTI_TOTAL_CONTEXTS);
1015        for cdf in &ctx.eob_multi {
1016            assert!(
1017                cdf.len() == DEFAULT_EOB_MULTI_8.len() || cdf.len() == DEFAULT_EOB_MULTI_16.len(),
1018                "EOB CDF must be 9 or 17 entries, got {}",
1019                cdf.len()
1020            );
1021        }
1022    }
1023
1024    #[test]
1025    fn test_eob_multi_cdf_length_small_blocks() {
1026        // Areas ≤ 64 (4×4, 4×8, 8×4, 8×8, 4×16, 16×4) use the 9-entry CDF.
1027        let ctx = CdfContext::new();
1028        for plane in 0..EOB_PLANE_COUNT {
1029            for tx in [TX_4X4, TX_4X8, TX_8X4, TX_8X8, TX_4X16, TX_16X4] {
1030                let cdf = ctx.get_eob_multi_cdf(eob_ctx_for(tx, plane));
1031                assert_eq!(
1032                    cdf.len(),
1033                    DEFAULT_EOB_MULTI_8.len(),
1034                    "tx_size_idx={tx} plane={plane} expected 9-entry CDF"
1035                );
1036                assert_eq!(cdf, &DEFAULT_EOB_MULTI_8[..]);
1037            }
1038        }
1039    }
1040
1041    #[test]
1042    fn test_eob_multi_cdf_length_large_blocks() {
1043        // Areas ≥ 128 use the 17-entry CDF.
1044        let ctx = CdfContext::new();
1045        for plane in 0..EOB_PLANE_COUNT {
1046            for tx in [
1047                TX_8X16, TX_16X8, TX_16X16, TX_16X32, TX_32X16, TX_32X32, TX_64X64,
1048            ] {
1049                let cdf = ctx.get_eob_multi_cdf(eob_ctx_for(tx, plane));
1050                assert_eq!(
1051                    cdf.len(),
1052                    DEFAULT_EOB_MULTI_16.len(),
1053                    "tx_size_idx={tx} plane={plane} expected 17-entry CDF"
1054                );
1055                assert_eq!(cdf, &DEFAULT_EOB_MULTI_16[..]);
1056            }
1057        }
1058    }
1059
1060    #[test]
1061    fn test_eob_multi_cdf_out_of_range_clamps() {
1062        let ctx = CdfContext::new();
1063        // ctx values ≥ EOB_MULTI_TOTAL_CONTEXTS clamp to the last slot.
1064        let clamped = ctx.get_eob_multi_cdf(EOB_MULTI_TOTAL_CONTEXTS + 100);
1065        let last = ctx.get_eob_multi_cdf(EOB_MULTI_TOTAL_CONTEXTS - 1);
1066        assert_eq!(clamped, last);
1067    }
1068
1069    #[test]
1070    fn test_eob_multi_cdf_adapts_via_mut() {
1071        // Mutate via the mutable getter and verify persistence through the
1072        // immutable getter. This is what the arithmetic coder relies on.
1073        let mut ctx = CdfContext::new();
1074        let key = eob_ctx_for(TX_4X4, 0);
1075
1076        let initial = ctx.get_eob_multi_cdf(key).to_vec();
1077        {
1078            let cdf_mut = ctx.get_eob_multi_cdf_mut(key);
1079            assert_eq!(cdf_mut.len(), DEFAULT_EOB_MULTI_8.len());
1080            let len = cdf_mut.len();
1081            // Bump every non-terminator entry; terminator is the count.
1082            for v in &mut cdf_mut[..len - 1] {
1083                *v = v.saturating_add(1);
1084            }
1085            cdf_mut[len - 1] = 5; // simulate adaptation count
1086        }
1087        let after = ctx.get_eob_multi_cdf(key).to_vec();
1088        assert_ne!(
1089            initial, after,
1090            "mutation must be visible via immutable getter"
1091        );
1092        assert_eq!(after.len(), initial.len());
1093        assert_eq!(after[after.len() - 1], 5);
1094    }
1095
1096    #[test]
1097    fn test_eob_multi_distinct_slots_per_plane() {
1098        // Mutating one (tx_size, plane) must not affect a different plane.
1099        let mut ctx = CdfContext::new();
1100        let key_luma = eob_ctx_for(TX_16X16, 0);
1101        let key_chroma = eob_ctx_for(TX_16X16, 1);
1102
1103        let chroma_before = ctx.get_eob_multi_cdf(key_chroma).to_vec();
1104        {
1105            let luma_mut = ctx.get_eob_multi_cdf_mut(key_luma);
1106            luma_mut[0] = 42;
1107        }
1108        let chroma_after = ctx.get_eob_multi_cdf(key_chroma).to_vec();
1109        let luma_after = ctx.get_eob_multi_cdf(key_luma).to_vec();
1110
1111        assert_eq!(chroma_before, chroma_after, "chroma slot untouched");
1112        assert_eq!(luma_after[0], 42, "luma mutation visible");
1113    }
1114
1115    #[test]
1116    fn test_eob_multi_reset_restores_defaults() {
1117        let mut ctx = CdfContext::new();
1118        let key = eob_ctx_for(TX_32X32, 2);
1119        {
1120            let cdf_mut = ctx.get_eob_multi_cdf_mut(key);
1121            cdf_mut[0] = 99;
1122            cdf_mut[1] = 100;
1123        }
1124        ctx.reset();
1125        let after = ctx.get_eob_multi_cdf(key);
1126        assert_eq!(
1127            after,
1128            &DEFAULT_EOB_MULTI_16[..],
1129            "reset must rebuild large-block default"
1130        );
1131    }
1132
1133    #[test]
1134    fn test_eob_multi_cdf_mut_recovers_after_clear() {
1135        // Defensive: if external code drains `eob_multi`, the mutable
1136        // getter must rebuild the table rather than panic / unwrap.
1137        let mut ctx = CdfContext::new();
1138        ctx.eob_multi.clear();
1139        let cdf = ctx.get_eob_multi_cdf_mut(0);
1140        assert!(!cdf.is_empty());
1141        assert_eq!(ctx.eob_multi.len(), EOB_MULTI_TOTAL_CONTEXTS);
1142    }
1143
1144    #[test]
1145    fn test_eob_multi_default_cdfs_are_valid() {
1146        // The EOB multi defaults must be monotonic and terminate at the
1147        // probability cap.
1148        assert!(is_valid_cdf(&DEFAULT_EOB_MULTI_2));
1149        assert!(is_valid_cdf(&DEFAULT_EOB_MULTI_4));
1150        assert!(is_valid_cdf(&DEFAULT_EOB_MULTI_8));
1151        assert!(is_valid_cdf(&DEFAULT_EOB_MULTI_16));
1152    }
1153}