Skip to main content

oximedia_codec/av1/
entropy_tables.rs

1//! AV1 entropy coding default CDF tables.
2//!
3//! This module contains the default probability tables (CDFs) used for
4//! entropy coding in AV1. These tables are used as initial values and
5//! are updated during decoding based on observed symbol frequencies.
6//!
7//! # CDF Format
8//!
9//! CDFs are stored as arrays of 16-bit unsigned integers. The last element
10//! is reserved for the symbol count used in CDF update. The actual
11//! probabilities are in 15-bit precision (0-32767).
12//!
13//! # Table Categories
14//!
15//! - **Partition CDFs** - Block partitioning decisions
16//! - **Intra Mode CDFs** - Intra prediction mode selection
17//! - **TX Size CDFs** - Transform size selection
18//! - **TX Type CDFs** - Transform type selection
19//! - **Coefficient CDFs** - Coefficient level and sign coding
20//! - **MV Component CDFs** - Motion vector component coding
21//!
22//! # Reference
23//!
24//! See AV1 Specification Section 9 for probability model initialization
25//! and update procedures.
26
27#![allow(dead_code)]
28#![allow(clippy::doc_markdown)]
29#![allow(clippy::needless_range_loop)]
30
// =============================================================================
// CDF Precision Constants
// =============================================================================

/// Number of fractional bits used to express CDF probabilities.
pub const CDF_PROB_BITS: u8 = 15;

/// Probability value that represents certainty (`2^CDF_PROB_BITS`); a
/// well-formed CDF ends with this value.
pub const CDF_PROB_TOP: u16 = 1u16 << CDF_PROB_BITS;

/// Adaptation-counter value stored in a freshly initialized CDF.
pub const CDF_INIT_COUNT: u16 = u16::MIN;

/// Saturation point of the adaptation counter; the update routine stops
/// incrementing the counter once it reaches this value.
pub const CDF_MAX_COUNT: u16 = 32;
46
// =============================================================================
// Partition CDFs
// =============================================================================

/// Number of partition contexts.
pub const PARTITION_CONTEXTS: usize = 4;

/// Number of partition types.
pub const PARTITION_TYPES: usize = 10;

/// All default partition CDFs, indexed by partition context.
///
/// Each row holds `PARTITION_TYPES` cumulative probabilities followed by the
/// adaptation-counter slot.
pub const DEFAULT_PARTITION_CDFS: [[u16; PARTITION_TYPES + 1]; PARTITION_CONTEXTS] = [
    // Context 0 (small blocks).
    [15588, 17570, 19323, 21084, 22472, 24311, 25744, 27999, 29223, 32768, 0],
    // Context 1 (medium blocks).
    [12064, 14616, 17239, 19824, 21631, 24068, 25919, 28400, 29760, 32768, 0],
    // Context 2 (large blocks).
    [9216, 12096, 15424, 18432, 20672, 23424, 25664, 28544, 30080, 32768, 0],
    // Context 3 (very large blocks).
    [6144, 9472, 13312, 16896, 19584, 22912, 25600, 28800, 30464, 32768, 0],
];

/// Default CDF for partition (context 0, small blocks).
pub const DEFAULT_PARTITION_CDF_0: [u16; 11] = DEFAULT_PARTITION_CDFS[0];

/// Default CDF for partition (context 1, medium blocks).
pub const DEFAULT_PARTITION_CDF_1: [u16; 11] = DEFAULT_PARTITION_CDFS[1];

/// Default CDF for partition (context 2, large blocks).
pub const DEFAULT_PARTITION_CDF_2: [u16; 11] = DEFAULT_PARTITION_CDFS[2];

/// Default CDF for partition (context 3, very large blocks).
pub const DEFAULT_PARTITION_CDF_3: [u16; 11] = DEFAULT_PARTITION_CDFS[3];
84
// =============================================================================
// Intra Mode CDFs
// =============================================================================

/// Number of intra modes.
pub const INTRA_MODES: usize = 13;

/// Number of intra mode contexts for Y.
pub const INTRA_Y_MODE_CONTEXTS: usize = 4;

/// All default Y intra mode CDFs, indexed by context.
///
/// Each row holds `INTRA_MODES` cumulative probabilities followed by the
/// adaptation-counter slot.
pub const DEFAULT_Y_MODE_CDFS: [[u16; INTRA_MODES + 1]; INTRA_Y_MODE_CONTEXTS] = [
    // Context 0.
    [15588, 17570, 18800, 20000, 21500, 23000, 24500, 26000, 27500, 29000, 30500, 31500, 32768, 0],
    // Context 1.
    [12064, 14616, 16500, 18500, 20500, 22500, 24500, 26500, 28000, 29500, 30750, 31750, 32768, 0],
    // Context 2.
    [9216, 12096, 14500, 17000, 19500, 22000, 24500, 26500, 28500, 30000, 31000, 32000, 32768, 0],
    // Context 3.
    [6144, 9472, 12500, 15500, 18500, 21500, 24500, 27000, 29000, 30500, 31250, 32000, 32768, 0],
];

/// Default CDF for Y intra mode (context 0).
pub const DEFAULT_Y_MODE_CDF_0: [u16; 14] = DEFAULT_Y_MODE_CDFS[0];

/// Default CDF for Y intra mode (context 1).
pub const DEFAULT_Y_MODE_CDF_1: [u16; 14] = DEFAULT_Y_MODE_CDFS[1];

/// Default CDF for Y intra mode (context 2).
pub const DEFAULT_Y_MODE_CDF_2: [u16; 14] = DEFAULT_Y_MODE_CDFS[2];

/// Default CDF for Y intra mode (context 3).
pub const DEFAULT_Y_MODE_CDF_3: [u16; 14] = DEFAULT_Y_MODE_CDFS[3];

/// Number of UV intra mode contexts.
pub const INTRA_UV_MODE_CONTEXTS: usize = 13;

/// Default CDF for UV intra mode when CFL is disabled
/// (13 cumulative probabilities plus the counter slot).
pub const DEFAULT_UV_MODE_CDF_NO_CFL: [u16; 14] = [
    22528, 24320, 25344, 26368, 27136, 28160, 28928, 29696, 30464, 31104, 31616, 32128, //
    32768, 0,
];

/// Default CDF for UV intra mode when CFL is enabled
/// (14 cumulative probabilities plus the counter slot).
pub const DEFAULT_UV_MODE_CDF_CFL: [u16; 15] = [
    18432, 20480, 22016, 23296, 24576, 25856, 27136, 28160, 29184, 30080, 30848, 31488, //
    32000, 32768, 0,
];
136
// =============================================================================
// Transform Size CDFs
// =============================================================================

/// Number of TX size contexts.
pub const TX_SIZE_CONTEXTS: usize = 3;

/// Number of max TX size categories.
pub const MAX_TX_CATS: usize = 4;

// Each table below is a (roughly) uniform CDF over the candidate sizes for
// its maximum transform size, followed by the adaptation-counter slot.

/// Default CDF for TX size (max 8x8): 2 candidates.
pub const DEFAULT_TX_SIZE_CDF_8X8: [u16; 3] = [16_384, 32_768, 0];

/// Default CDF for TX size (max 16x16): 3 candidates.
pub const DEFAULT_TX_SIZE_CDF_16X16: [u16; 4] = [10_923, 21_845, 32_768, 0];

/// Default CDF for TX size (max 32x32): 4 candidates.
pub const DEFAULT_TX_SIZE_CDF_32X32: [u16; 5] = [8_192, 16_384, 24_576, 32_768, 0];

/// Default CDF for TX size (max 64x64): 5 candidates.
pub const DEFAULT_TX_SIZE_CDF_64X64: [u16; 6] = [6_554, 13_107, 19_661, 26_214, 32_768, 0];
158
// =============================================================================
// Transform Type CDFs
// =============================================================================

/// Number of TX type contexts per set.
pub const TX_TYPE_CONTEXTS: usize = 7;

/// Number of transform types for intra.
pub const INTRA_TX_TYPES: usize = 7;

/// Number of transform types for inter.
pub const INTER_TX_TYPES: usize = 16;

/// Default CDF for intra TX type (TX_4X4), one row per TX type context.
///
/// NOTE(review): unlike the other tables in this module, these rows carry no
/// trailing adaptation-counter slot. All 8 entries are probabilities ending at
/// 32768, which describes an 8-symbol alphabet, while `INTRA_TX_TYPES` is 7.
/// A routine that treats the last element as the counter (as `update_cdf`
/// does) would misread these rows. Confirm the intended alphabet size before
/// adapting them.
pub const DEFAULT_INTRA_TX_TYPE_4X4: [[u16; 8]; TX_TYPE_CONTEXTS] = [
    [5461, 10923, 16384, 21845, 24576, 27307, 30037, 32768], // ctx 0
    [4681, 9362, 14043, 18725, 22118, 25512, 28905, 32768],  // ctx 1
    [4096, 8192, 12288, 16384, 20480, 24576, 28672, 32768],  // ctx 2
    [3641, 7282, 10923, 14564, 18893, 23222, 27551, 32768],  // ctx 3
    [3277, 6554, 9830, 13107, 17476, 21845, 26214, 32768],   // ctx 4
    [2979, 5958, 8937, 11916, 16213, 20511, 25228, 32768],   // ctx 5
    [2731, 5461, 8192, 10923, 15019, 19114, 24064, 32768],   // ctx 6
];

/// Default CDF for intra TX type (TX_8X8), one row per TX type context.
///
/// NOTE(review): same layout caveat as the TX_4X4 table: 8 probabilities per
/// row and no counter slot.
pub const DEFAULT_INTRA_TX_TYPE_8X8: [[u16; 8]; TX_TYPE_CONTEXTS] = [
    [6144, 12288, 18432, 24576, 26624, 28672, 30720, 32768], // ctx 0
    [5461, 10923, 16384, 21845, 24576, 27307, 30037, 32768], // ctx 1
    [4915, 9830, 14745, 19660, 22938, 26214, 29491, 32768],  // ctx 2
    [4455, 8909, 13364, 17818, 21399, 24980, 28561, 32768],  // ctx 3
    [4096, 8192, 12288, 16384, 20070, 23756, 27852, 32768],  // ctx 4
    [3780, 7559, 11339, 15119, 18897, 22675, 27200, 32768],  // ctx 5
    [3495, 6991, 10486, 13981, 17827, 21673, 26600, 32768],  // ctx 6
];
193
194/// Default CDF for inter TX type.
195pub const DEFAULT_INTER_TX_TYPE: [[u16; 17]; TX_TYPE_CONTEXTS] = [
196    [
197        2048, 4096, 6144, 8192, 10240, 12288, 14336, 16384, 18432, 20480, 22528, 24576, 26624,
198        28672, 30720, 32768, 0,
199    ],
200    [
201        2048, 4096, 6144, 8192, 10240, 12288, 14336, 16384, 18432, 20480, 22528, 24576, 26624,
202        28672, 30720, 32768, 0,
203    ],
204    [
205        2048, 4096, 6144, 8192, 10240, 12288, 14336, 16384, 18432, 20480, 22528, 24576, 26624,
206        28672, 30720, 32768, 0,
207    ],
208    [
209        2048, 4096, 6144, 8192, 10240, 12288, 14336, 16384, 18432, 20480, 22528, 24576, 26624,
210        28672, 30720, 32768, 0,
211    ],
212    [
213        2048, 4096, 6144, 8192, 10240, 12288, 14336, 16384, 18432, 20480, 22528, 24576, 26624,
214        28672, 30720, 32768, 0,
215    ],
216    [
217        2048, 4096, 6144, 8192, 10240, 12288, 14336, 16384, 18432, 20480, 22528, 24576, 26624,
218        28672, 30720, 32768, 0,
219    ],
220    [
221        2048, 4096, 6144, 8192, 10240, 12288, 14336, 16384, 18432, 20480, 22528, 24576, 26624,
222        28672, 30720, 32768, 0,
223    ],
224];
225
// =============================================================================
// Coefficient CDFs
// =============================================================================

/// Number of EOB multi contexts per plane type.
pub const EOB_MULTI_CONTEXTS: usize = 7;

// The EOB multi tables are uniform over 2, 4, 8 and 16 symbols respectively,
// each followed by the adaptation-counter slot.

/// Default CDF for EOB multi (2 symbols).
pub const DEFAULT_EOB_MULTI_2: [u16; 3] = [16_384, 32_768, 0];

/// Default CDF for EOB multi (4 symbols).
pub const DEFAULT_EOB_MULTI_4: [u16; 5] = [8_192, 16_384, 24_576, 32_768, 0];

/// Default CDF for EOB multi (8 symbols).
pub const DEFAULT_EOB_MULTI_8: [u16; 9] = [
    4_096, 8_192, 12_288, 16_384, 20_480, 24_576, 28_672, 32_768, //
    0,
];

/// Default CDF for EOB multi (16 symbols).
pub const DEFAULT_EOB_MULTI_16: [u16; 17] = [
    2_048, 4_096, 6_144, 8_192, 10_240, 12_288, 14_336, 16_384, //
    18_432, 20_480, 22_528, 24_576, 26_624, 28_672, 30_720, 32_768, //
    0,
];

/// Number of coefficient base contexts.
pub const COEFF_BASE_CTX_COUNT: usize = 42;

/// Default CDF for coefficient base (4 levels: 0, 1, 2, >2), uniform.
pub const DEFAULT_COEFF_BASE_CDF: [u16; 5] = [8_192, 16_384, 24_576, 32_768, 0];

/// Default CDF for coefficient base EOB (3 symbols, uniform).
pub const DEFAULT_COEFF_BASE_EOB_CDF: [u16; 4] = [10_923, 21_845, 32_768, 0];

/// Number of DC sign contexts.
pub const DC_SIGN_CTX_COUNT: usize = 3;

/// Default CDF for DC sign (equiprobable).
pub const DEFAULT_DC_SIGN_CDF: [u16; 3] = [16_384, 32_768, 0];

/// Number of coefficient base range contexts.
pub const COEFF_BR_CTX_COUNT: usize = 21;

/// Default CDF for coefficient base range (3 symbols, uniform).
pub const DEFAULT_COEFF_BR_CDF: [u16; 4] = [10_923, 21_845, 32_768, 0];
268
// =============================================================================
// Motion Vector CDFs
// =============================================================================

/// Number of MV joint types.
pub const MV_JOINTS: usize = 4;

/// Default CDF for MV joint (`MV_JOINTS` probabilities plus counter slot).
pub const DEFAULT_MV_JOINT_CDF: [u16; MV_JOINTS + 1] = [4096, 11264, 19712, 32768, 0];

/// Number of MV classes.
pub const MV_CLASSES: usize = 11;

/// Default CDF for MV class (`MV_CLASSES` probabilities plus counter slot),
/// heavily skewed toward the smallest class.
pub const DEFAULT_MV_CLASS_CDF: [u16; MV_CLASSES + 1] = [
    28672, 30976, 31744, 32128, 32320, 32448, 32544, 32608, 32672, 32720, 32768, //
    0,
];

/// Default CDF for MV class 0 bit (equiprobable).
pub const DEFAULT_MV_CLASS0_BIT_CDF: [u16; 3] = [16384, 32768, 0];

/// Number of MV class 0 fractional values.
pub const MV_CLASS0_FP: usize = 4;

/// Default CDF for MV class 0 fractional, two identical uniform rows.
pub const DEFAULT_MV_CLASS0_FP_CDF: [[u16; MV_CLASS0_FP + 1]; 2] =
    [[8192, 16384, 24576, 32768, 0]; 2];

/// Number of MV fractional values.
pub const MV_FP: usize = 4;

/// Default CDF for MV fractional (uniform).
pub const DEFAULT_MV_FP_CDF: [u16; MV_FP + 1] = [8192, 16384, 24576, 32768, 0];

/// Default CDF for MV class 0 high precision (equiprobable).
pub const DEFAULT_MV_CLASS0_HP_CDF: [u16; 3] = [16384, 32768, 0];

/// Default CDF for MV high precision (equiprobable).
pub const DEFAULT_MV_HP_CDF: [u16; 3] = [16384, 32768, 0];

/// Default CDF for MV sign (equiprobable).
pub const DEFAULT_MV_SIGN_CDF: [u16; 3] = [16384, 32768, 0];

/// Number of MV bits for class > 0.
pub const MV_OFFSET_BITS: usize = 10;

/// Default CDF for MV bits: one equiprobable binary CDF per offset bit.
pub const DEFAULT_MV_BITS_CDF: [[u16; 3]; MV_OFFSET_BITS] = [[16384, 32768, 0]; MV_OFFSET_BITS];
330
// =============================================================================
// Skip CDFs
// =============================================================================

/// Number of skip contexts.
pub const SKIP_CONTEXTS: usize = 3;

/// Default CDF for skip, one binary CDF per skip context.
pub const DEFAULT_SKIP_CDF: [[u16; 3]; SKIP_CONTEXTS] = [
    [24576, 32768, 0], // ctx 0
    [16384, 32768, 0], // ctx 1
    [8192, 32768, 0],  // ctx 2
];
341
// =============================================================================
// Segment CDFs
// =============================================================================

/// Maximum number of segments.
pub const MAX_SEGMENTS: usize = 8;

/// Default CDF for segment ID (tree): uniform over `MAX_SEGMENTS` IDs plus
/// the counter slot.
pub const DEFAULT_SEGMENT_TREE_CDF: [u16; MAX_SEGMENTS + 1] = [
    4096, 8192, 12288, 16384, 20480, 24576, 28672, 32768, //
    0,
];

/// Default CDF for segment ID prediction: three equiprobable binary CDFs.
pub const DEFAULT_SEGMENT_PRED_CDF: [[u16; 3]; 3] = [[16384, 32768, 0]; 3];
356
// =============================================================================
// Reference Frame CDFs
// =============================================================================

/// Number of reference frame contexts.
pub const REF_CONTEXTS: usize = 3;

/// Number of reference frame types for single ref.
pub const SINGLE_REF_TYPES: usize = 7;

/// Default CDF for single reference frame, indexed `[context][ref_type]`.
///
/// Every entry starts as the same equiprobable binary CDF.
pub const DEFAULT_SINGLE_REF_CDF: [[[u16; 3]; SINGLE_REF_TYPES]; REF_CONTEXTS] =
    [[[16384, 32768, 0]; SINGLE_REF_TYPES]; REF_CONTEXTS];
397
// =============================================================================
// Inter Mode CDFs
// =============================================================================

/// Number of inter mode contexts.
pub const INTER_MODE_CONTEXTS: usize = 8;

/// Number of inter modes.
pub const INTER_MODES: usize = 4;

/// Default CDF for inter mode, one row per inter mode context.
///
/// Each row holds `INTER_MODES` cumulative probabilities followed by the
/// adaptation-counter slot.
pub const DEFAULT_INTER_MODE_CDF: [[u16; INTER_MODES + 1]; INTER_MODE_CONTEXTS] = [
    [2048, 10240, 17664, 32768, 0],  // ctx 0
    [4096, 12288, 20480, 32768, 0],  // ctx 1
    [6144, 14336, 22528, 32768, 0],  // ctx 2
    [8192, 16384, 24576, 32768, 0],  // ctx 3
    [10240, 18432, 26624, 32768, 0], // ctx 4
    [12288, 20480, 28672, 32768, 0], // ctx 5
    [14336, 22528, 29696, 32768, 0], // ctx 6
    [16384, 24576, 30720, 32768, 0], // ctx 7
];
419
// =============================================================================
// Compound Mode CDFs
// =============================================================================

/// Number of compound mode contexts.
pub const COMPOUND_MODE_CONTEXTS: usize = 8;

/// Number of compound modes.
pub const COMPOUND_MODES: usize = 8;

/// Default CDF for compound mode: every context starts uniform over the
/// `COMPOUND_MODES` modes, followed by the adaptation-counter slot.
pub const DEFAULT_COMPOUND_MODE_CDF: [[u16; COMPOUND_MODES + 1]; COMPOUND_MODE_CONTEXTS] =
    [[4096, 8192, 12288, 16384, 20480, 24576, 28672, 32768, 0]; COMPOUND_MODE_CONTEXTS];
441
// =============================================================================
// Filter CDFs
// =============================================================================

/// Number of interpolation filter types.
pub const INTERP_FILTERS: usize = 4;

/// Number of interpolation filter contexts.
pub const INTERP_FILTER_CONTEXTS: usize = 16;

/// Default CDF for interpolation filter, one row per filter context.
///
/// Contexts come in four groups of four identical rows; each row holds
/// `INTERP_FILTERS` cumulative probabilities followed by the counter slot.
pub const DEFAULT_INTERP_FILTER_CDF: [[u16; INTERP_FILTERS + 1]; INTERP_FILTER_CONTEXTS] = [
    // Contexts 0..=3.
    [6144, 12288, 18432, 32768, 0],
    [6144, 12288, 18432, 32768, 0],
    [6144, 12288, 18432, 32768, 0],
    [6144, 12288, 18432, 32768, 0],
    // Contexts 4..=7.
    [8192, 16384, 24576, 32768, 0],
    [8192, 16384, 24576, 32768, 0],
    [8192, 16384, 24576, 32768, 0],
    [8192, 16384, 24576, 32768, 0],
    // Contexts 8..=11.
    [10240, 18432, 26624, 32768, 0],
    [10240, 18432, 26624, 32768, 0],
    [10240, 18432, 26624, 32768, 0],
    [10240, 18432, 26624, 32768, 0],
    // Contexts 12..=15.
    [12288, 20480, 28672, 32768, 0],
    [12288, 20480, 28672, 32768, 0],
    [12288, 20480, 28672, 32768, 0],
    [12288, 20480, 28672, 32768, 0],
];
471
472// =============================================================================
473// CDF Helper Functions
474// =============================================================================
475
476/// Create a uniform CDF for n symbols.
477#[must_use]
478#[allow(clippy::cast_possible_truncation)]
479pub fn create_uniform_cdf(n: usize) -> Vec<u16> {
480    let mut cdf = Vec::with_capacity(n + 1);
481    for i in 1..=n {
482        cdf.push(((i * CDF_PROB_TOP as usize) / n) as u16);
483    }
484    cdf.push(CDF_INIT_COUNT); // Symbol count
485    cdf
486}
487
488/// Update CDF after observing a symbol.
489#[allow(clippy::cast_possible_truncation)]
490pub fn update_cdf(cdf: &mut [u16], symbol: usize) {
491    let n = cdf.len() - 1; // Last element is count
492    if n == 0 {
493        return;
494    }
495
496    // Get current count and compute rate
497    let count = u32::from(cdf[n]);
498    let rate = 3 + (count >> 4);
499    let rate = rate.min(32);
500
501    // Update CDF values
502    for i in 0..n {
503        if i < symbol {
504            // Decrease probability
505            let diff = cdf[i] >> rate;
506            cdf[i] = cdf[i].saturating_sub(diff);
507        } else {
508            // Increase probability
509            let diff = CDF_PROB_TOP.saturating_sub(cdf[i]) >> rate;
510            cdf[i] = cdf[i].saturating_add(diff);
511        }
512    }
513
514    // Increment count
515    if count < u32::from(CDF_MAX_COUNT) {
516        cdf[n] += 1;
517    }
518}
519
520/// Reset CDF to uniform distribution.
521#[allow(clippy::cast_possible_truncation)]
522pub fn reset_cdf_uniform(cdf: &mut [u16]) {
523    let n = cdf.len() - 1;
524    if n == 0 {
525        return;
526    }
527
528    for i in 0..n {
529        cdf[i] = (((i + 1) * CDF_PROB_TOP as usize) / n) as u16;
530    }
531    cdf[n] = CDF_INIT_COUNT;
532}
533
/// Copies CDF values from `src` into `dst`.
///
/// Only the overlapping prefix is copied. When the slices differ in length,
/// the extra elements of `src` are ignored and the tail of `dst` is left
/// as it was.
pub fn copy_cdf(dst: &mut [u16], src: &[u16]) {
    for (slot, &value) in dst.iter_mut().zip(src) {
        *slot = value;
    }
}
539
540/// Check if a CDF is valid (monotonically increasing, ends at CDF_PROB_TOP).
541#[must_use]
542pub fn is_valid_cdf(cdf: &[u16]) -> bool {
543    if cdf.len() < 2 {
544        return false;
545    }
546
547    let n = cdf.len() - 1;
548
549    // Check monotonicity
550    for i in 1..n {
551        if cdf[i] < cdf[i - 1] {
552            return false;
553        }
554    }
555
556    // Last probability should be CDF_PROB_TOP
557    cdf[n - 1] == CDF_PROB_TOP
558}
559
560// =============================================================================
561// CDF Context Management
562// =============================================================================
563
564/// Container for all CDF tables used in decoding.
565#[derive(Clone, Debug)]
566pub struct CdfContext {
567    /// Partition CDFs.
568    pub partition: [[u16; 11]; PARTITION_CONTEXTS],
569    /// Y intra mode CDFs.
570    pub y_mode: [[u16; 14]; INTRA_Y_MODE_CONTEXTS],
571    /// Skip CDFs.
572    pub skip: [[u16; 3]; SKIP_CONTEXTS],
573    /// MV joint CDF.
574    pub mv_joint: [u16; 5],
575    /// MV sign CDFs (for each component).
576    pub mv_sign: [[u16; 3]; 2],
577    /// MV class CDFs.
578    pub mv_class: [[u16; 12]; 2],
579    /// DC sign CDFs.
580    pub dc_sign: [[u16; 3]; DC_SIGN_CTX_COUNT],
581    /// Coefficient base CDFs.
582    pub coeff_base: Vec<[u16; 5]>,
583    /// Coefficient base range CDFs.
584    pub coeff_br: Vec<[u16; 4]>,
585}
586
587impl CdfContext {
588    /// Create a new CDF context with default values.
589    #[must_use]
590    pub fn new() -> Self {
591        Self {
592            partition: DEFAULT_PARTITION_CDFS,
593            y_mode: DEFAULT_Y_MODE_CDFS,
594            skip: DEFAULT_SKIP_CDF,
595            mv_joint: DEFAULT_MV_JOINT_CDF,
596            mv_sign: [DEFAULT_MV_SIGN_CDF, DEFAULT_MV_SIGN_CDF],
597            mv_class: [DEFAULT_MV_CLASS_CDF, DEFAULT_MV_CLASS_CDF],
598            dc_sign: [DEFAULT_DC_SIGN_CDF; DC_SIGN_CTX_COUNT],
599            coeff_base: vec![DEFAULT_COEFF_BASE_CDF; COEFF_BASE_CTX_COUNT],
600            coeff_br: vec![DEFAULT_COEFF_BR_CDF; COEFF_BR_CTX_COUNT],
601        }
602    }
603
604    /// Reset all CDFs to default values.
605    pub fn reset(&mut self) {
606        self.partition = DEFAULT_PARTITION_CDFS;
607        self.y_mode = DEFAULT_Y_MODE_CDFS;
608        self.skip = DEFAULT_SKIP_CDF;
609        self.mv_joint = DEFAULT_MV_JOINT_CDF;
610        self.mv_sign = [DEFAULT_MV_SIGN_CDF, DEFAULT_MV_SIGN_CDF];
611        self.mv_class = [DEFAULT_MV_CLASS_CDF, DEFAULT_MV_CLASS_CDF];
612        self.dc_sign = [DEFAULT_DC_SIGN_CDF; DC_SIGN_CTX_COUNT];
613
614        for cdf in &mut self.coeff_base {
615            *cdf = DEFAULT_COEFF_BASE_CDF;
616        }
617
618        for cdf in &mut self.coeff_br {
619            *cdf = DEFAULT_COEFF_BR_CDF;
620        }
621    }
622
623    /// Get partition CDF for a context.
624    #[must_use]
625    pub fn get_partition_cdf(&self, ctx: usize) -> &[u16; 11] {
626        &self.partition[ctx.min(PARTITION_CONTEXTS - 1)]
627    }
628
629    /// Get mutable partition CDF for a context.
630    pub fn get_partition_cdf_mut(&mut self, ctx: usize) -> &mut [u16; 11] {
631        &mut self.partition[ctx.min(PARTITION_CONTEXTS - 1)]
632    }
633
634    /// Get Y mode CDF for a context.
635    #[must_use]
636    pub fn get_y_mode_cdf(&self, ctx: usize) -> &[u16; 14] {
637        &self.y_mode[ctx.min(INTRA_Y_MODE_CONTEXTS - 1)]
638    }
639
640    /// Get skip CDF for a context.
641    #[must_use]
642    pub fn get_skip_cdf(&self, ctx: usize) -> &[u16; 3] {
643        &self.skip[ctx.min(SKIP_CONTEXTS - 1)]
644    }
645
646    /// Get EOB multi CDF for a context.
647    #[must_use]
648    pub fn get_eob_multi_cdf(&self, _ctx: usize) -> &[u16] {
649        // Stub: Return default uniform CDF
650        &[8192, 16384, 24576, 32768]
651    }
652
653    /// Get coefficient base CDF for a context.
654    #[must_use]
655    pub fn get_coeff_base_cdf(&self, ctx: usize) -> &[u16] {
656        if ctx < self.coeff_base.len() {
657            &self.coeff_base[ctx]
658        } else {
659            &DEFAULT_COEFF_BASE_CDF
660        }
661    }
662
663    /// Get coefficient base EOB CDF for a context.
664    #[must_use]
665    pub fn get_coeff_base_eob_cdf(&self, ctx: usize) -> &[u16] {
666        // For EOB position, use the same as coeff_base
667        self.get_coeff_base_cdf(ctx)
668    }
669
670    /// Get coefficient BR (base range) CDF for a context.
671    #[must_use]
672    pub fn get_coeff_br_cdf(&self, ctx: usize) -> &[u16] {
673        if ctx < self.coeff_br.len() {
674            &self.coeff_br[ctx]
675        } else {
676            &DEFAULT_COEFF_BR_CDF
677        }
678    }
679
680    /// Get DC sign CDF for a context.
681    #[must_use]
682    pub fn get_dc_sign_cdf(&self, ctx: usize) -> &[u16] {
683        &self.dc_sign[ctx.min(DC_SIGN_CTX_COUNT - 1)]
684    }
685
686    // Mutable versions for updating CDFs during decoding
687
688    /// Get mutable EOB multi CDF for a context.
689    pub fn get_eob_multi_cdf_mut(&mut self, _ctx: usize) -> &mut [u16] {
690        // Stub: Return a mutable slice (can't return static, so use first coeff_base)
691        if !self.coeff_base.is_empty() {
692            let len = self.coeff_base[0].len();
693            let max_len = 4.min(len);
694            if max_len > 0 {
695                return &mut self.coeff_base[0][..max_len];
696            }
697        }
698        // Return an empty slice from dc_sign as fallback
699        if !self.dc_sign.is_empty() {
700            &mut self.dc_sign[0][..0]
701        } else {
702            &mut []
703        }
704    }
705
706    /// Get mutable coefficient base CDF for a context.
707    pub fn get_coeff_base_cdf_mut(&mut self, ctx: usize) -> &mut [u16] {
708        if ctx < self.coeff_base.len() {
709            &mut self.coeff_base[ctx]
710        } else if !self.coeff_base.is_empty() {
711            &mut self.coeff_base[0]
712        } else {
713            &mut []
714        }
715    }
716
717    /// Get mutable coefficient base EOB CDF for a context.
718    pub fn get_coeff_base_eob_cdf_mut(&mut self, ctx: usize) -> &mut [u16] {
719        self.get_coeff_base_cdf_mut(ctx)
720    }
721
722    /// Get mutable coefficient BR (base range) CDF for a context.
723    pub fn get_coeff_br_cdf_mut(&mut self, ctx: usize) -> &mut [u16] {
724        if ctx < self.coeff_br.len() {
725            &mut self.coeff_br[ctx]
726        } else if !self.coeff_br.is_empty() {
727            &mut self.coeff_br[0]
728        } else {
729            &mut []
730        }
731    }
732
733    /// Get mutable DC sign CDF for a context.
734    pub fn get_dc_sign_cdf_mut(&mut self, ctx: usize) -> &mut [u16] {
735        let idx = ctx.min(DC_SIGN_CTX_COUNT - 1);
736        &mut self.dc_sign[idx]
737    }
738}
739
740impl Default for CdfContext {
741    fn default() -> Self {
742        Self::new()
743    }
744}
745
// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_create_uniform_cdf() {
        // Four equal steps of 8192, then the zeroed counter slot.
        assert_eq!(create_uniform_cdf(4), vec![8192, 16384, 24576, 32768, 0]);
    }

    #[test]
    fn test_update_cdf() {
        let mut cdf = create_uniform_cdf(4);
        let before = cdf[0];

        update_cdf(&mut cdf, 0);

        assert!(
            cdf[0] >= before,
            "observing symbol 0 must not lower its cumulative probability"
        );
    }

    #[test]
    fn test_reset_cdf_uniform() {
        let mut cdf = [0u16; 5];
        reset_cdf_uniform(&mut cdf);

        assert_eq!((cdf[0], cdf[3], cdf[4]), (8192, 32768, 0));
    }

    #[test]
    fn test_copy_cdf() {
        let src = create_uniform_cdf(4);
        let mut dst = vec![0u16; src.len()];

        copy_cdf(&mut dst, &src);

        assert_eq!(dst, src);
    }

    #[test]
    fn test_is_valid_cdf() {
        // Not monotonic: 50 follows 100.
        let broken = [100u16, 50, 200, 32768, 0];
        assert!(!is_valid_cdf(&broken));

        assert!(is_valid_cdf(&create_uniform_cdf(4)));
    }

    #[test]
    fn test_cdf_context_new() {
        let ctx = CdfContext::new();
        assert_eq!(
            (ctx.partition.len(), ctx.y_mode.len()),
            (PARTITION_CONTEXTS, INTRA_Y_MODE_CONTEXTS)
        );
    }

    #[test]
    fn test_cdf_context_reset() {
        let mut ctx = CdfContext::new();
        ctx.get_partition_cdf_mut(0)[0] = 12345;

        ctx.reset();

        assert_eq!(ctx.partition[0], DEFAULT_PARTITION_CDFS[0]);
    }

    #[test]
    fn test_get_partition_cdf() {
        let ctx = CdfContext::new();

        assert_eq!(ctx.get_partition_cdf(0), &DEFAULT_PARTITION_CDF_0);
        // Contexts past the end clamp to the last one.
        assert_eq!(ctx.get_partition_cdf(100), &DEFAULT_PARTITION_CDF_3);
    }

    #[test]
    fn test_default_cdfs_valid() {
        assert!(DEFAULT_PARTITION_CDFS.iter().all(|cdf| is_valid_cdf(cdf)));
        assert!(DEFAULT_Y_MODE_CDFS.iter().all(|cdf| is_valid_cdf(cdf)));
        assert!(is_valid_cdf(&DEFAULT_MV_JOINT_CDF));
        assert!(is_valid_cdf(&DEFAULT_MV_CLASS_CDF));
    }

    #[test]
    fn test_cdf_constants() {
        assert_eq!(
            (CDF_PROB_BITS, CDF_PROB_TOP, CDF_MAX_COUNT),
            (15, 32768, 32)
        );
    }

    #[test]
    fn test_partition_contexts() {
        assert_eq!((PARTITION_CONTEXTS, PARTITION_TYPES), (4, 10));
    }

    #[test]
    fn test_intra_mode_contexts() {
        assert_eq!((INTRA_MODES, INTRA_Y_MODE_CONTEXTS), (13, 4));
    }

    #[test]
    fn test_mv_constants() {
        assert_eq!((MV_JOINTS, MV_CLASSES, MV_OFFSET_BITS), (4, 11, 10));
    }
}