rav1e/context/
transform_unit.rs

1// Copyright (c) 2017-2022, The rav1e contributors. All rights reserved
2//
3// This source code is subject to the terms of the BSD 2 Clause License and
4// the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
5// was not distributed with this source code in the LICENSE file, you can
6// obtain it at www.aomedia.org/license/software. If the Alliance for Open
7// Media Patent License 1.0 was not distributed with this source code in the
8// PATENTS file, you can obtain it at www.aomedia.org/license/patent.
9
10use super::*;
11use crate::predict::PredictionMode;
12use crate::predict::PredictionMode::*;
13use crate::transform::TxType::*;
14use std::mem::MaybeUninit;
15
/// Largest transform edge length; `TX_64X64` is the biggest entry in the
/// lookup tables of this file.
pub const MAX_TX_SIZE: usize = 64;

/// Largest transform edge length whose coefficients are coded.
pub const MAX_CODED_TX_SIZE: usize = 32;
/// Number of coefficients in a `MAX_CODED_TX_SIZE` x `MAX_CODED_TX_SIZE`
/// block (32 * 32 = 1024).
pub const MAX_CODED_TX_SQUARE: usize = MAX_CODED_TX_SIZE * MAX_CODED_TX_SIZE;

pub const TX_SIZE_SQR_CONTEXTS: usize = 4; // Coded tx_size <= 32x32, so is the # of CDF contexts from tx sizes

/// Number of transform set types: the length of `num_tx_set` and the outer
/// dimension of `av1_tx_used` and `av1_tx_ind`.
pub const TX_SETS: usize = 6;
// NOTE(review): presumably the number of set indices available to intra and
// inter blocks respectively; the tables they size are not in this file.
pub const TX_SETS_INTRA: usize = 3;
pub const TX_SETS_INTER: usize = 4;

/// Number of luma intra prediction modes; length of
/// `intra_mode_to_tx_type_context`.
pub const INTRA_MODES: usize = 13;
/// Number of chroma intra prediction modes (the luma modes plus CFL); length
/// of `uv2y`.
pub const UV_INTRA_MODES: usize = 14;

/// Recursion depth at which `write_tx_size_inter` stops coding a split flag.
const MAX_VARTX_DEPTH: usize = 2;

// NOTE(review): presumably the length of `txfm_partition_cdf`;
// `txfm_partition_context` debug-asserts its `category` against this value.
pub const TXFM_PARTITION_CONTEXTS: usize =
  (TxSize::TX_SIZES - TxSize::TX_8X8 as usize) * 6 - 3;
34
// Number of transform types in each set type
pub static num_tx_set: [usize; TX_SETS] = [1, 2, 5, 7, 12, 16];
/// Membership table: `av1_tx_used[set][tx_type]` is non-zero when `tx_type`
/// belongs to `set`. Each row holds exactly `num_tx_set[set]` non-zero
/// entries. `write_tx_type` asserts membership before coding a type.
pub static av1_tx_used: [[usize; TX_TYPES]; TX_SETS] = [
  [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  [1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
  [1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
  [1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0],
  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
];
45
// Maps set types above to the indices used for intra
// NOTE(review): -1 presumably marks set types never selected for intra
// blocks; `get_tx_set_index` returns the entry without checking it.
static tx_set_index_intra: [i8; TX_SETS] = [0, -1, 2, 1, -1, -1];
// Maps set types above to the indices used for inter
// NOTE(review): -1 presumably marks set types never selected for inter
// blocks; `get_tx_set_index` returns the entry without checking it.
static tx_set_index_inter: [i8; TX_SETS] = [0, 3, -1, -1, 2, 1];
50
/// Symbol table: `av1_tx_ind[set][tx_type]` is the symbol value that
/// `write_tx_type` codes for `tx_type` when the block uses `set`.
// NOTE(review): entries for types outside a set are 0 and look unused, since
// `write_tx_type` asserts `av1_tx_used` before reading this table.
pub static av1_tx_ind: [[usize; TX_TYPES]; TX_SETS] = [
  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  [1, 3, 4, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  [1, 5, 6, 4, 0, 0, 0, 0, 0, 0, 2, 3, 0, 0, 0, 0],
  [3, 4, 5, 8, 6, 7, 9, 10, 11, 0, 1, 2, 0, 0, 0, 0],
  [7, 8, 9, 12, 10, 11, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6],
];
59
/// Largest transform size (possibly rectangular) for each block size,
/// indexed by `BlockSize as usize`. Every block size with a 128-pixel side
/// is capped at `TX_64X64`.
pub static max_txsize_rect_lookup: [TxSize; BlockSize::BLOCK_SIZES_ALL] = [
  TX_4X4,   // 4x4
  TX_4X8,   // 4x8
  TX_8X4,   // 8x4
  TX_8X8,   // 8x8
  TX_8X16,  // 8x16
  TX_16X8,  // 16x8
  TX_16X16, // 16x16
  TX_16X32, // 16x32
  TX_32X16, // 32x16
  TX_32X32, // 32x32
  TX_32X64, // 32x64
  TX_64X32, // 64x32
  TX_64X64, // 64x64
  TX_64X64, // 64x128
  TX_64X64, // 128x64
  TX_64X64, // 128x128
  TX_4X16,  // 4x16
  TX_16X4,  // 16x4
  TX_8X32,  // 8x32
  TX_32X8,  // 32x8
  TX_16X64, // 16x64
  TX_64X16, // 64x16
];
84
/// Transform size obtained by splitting a transform one level, indexed by
/// `TxSize as usize`. `TX_4X4` maps to itself, which is what terminates the
/// depth-counting loops in `write_tx_size_intra`.
pub static sub_tx_size_map: [TxSize; TxSize::TX_SIZES_ALL] = [
  TX_4X4,   // TX_4X4
  TX_4X4,   // TX_8X8
  TX_8X8,   // TX_16X16
  TX_16X16, // TX_32X32
  TX_32X32, // TX_64X64
  TX_4X4,   // TX_4X8
  TX_4X4,   // TX_8X4
  TX_8X8,   // TX_8X16
  TX_8X8,   // TX_16X8
  TX_16X16, // TX_16X32
  TX_16X16, // TX_32X16
  TX_32X32, // TX_32X64
  TX_32X32, // TX_64X32
  TX_4X8,   // TX_4X16
  TX_8X4,   // TX_16X4
  TX_8X16,  // TX_8X32
  TX_16X8,  // TX_32X8
  TX_16X32, // TX_16X64
  TX_32X16, // TX_64X16
];
106
107#[inline]
108pub fn has_chroma(
109  bo: TileBlockOffset, bsize: BlockSize, subsampling_x: usize,
110  subsampling_y: usize, chroma_sampling: ChromaSampling,
111) -> bool {
112  if chroma_sampling == ChromaSampling::Cs400 {
113    return false;
114  };
115
116  let bw = bsize.width_mi();
117  let bh = bsize.height_mi();
118
119  ((bo.0.x & 0x01) == 1 || (bw & 0x01) == 0 || subsampling_x == 0)
120    && ((bo.0.y & 0x01) == 1 || (bh & 0x01) == 0 || subsampling_y == 0)
121}
122
123pub fn get_tx_set(
124  tx_size: TxSize, is_inter: bool, use_reduced_set: bool,
125) -> TxSet {
126  let tx_size_sqr_up = tx_size.sqr_up();
127  let tx_size_sqr = tx_size.sqr();
128
129  if tx_size_sqr_up.block_size() > BlockSize::BLOCK_32X32 {
130    return TxSet::TX_SET_DCTONLY;
131  }
132
133  if is_inter {
134    if use_reduced_set || tx_size_sqr_up == TxSize::TX_32X32 {
135      TxSet::TX_SET_INTER_3
136    } else if tx_size_sqr == TxSize::TX_16X16 {
137      TxSet::TX_SET_INTER_2
138    } else {
139      TxSet::TX_SET_INTER_1
140    }
141  } else if tx_size_sqr_up == TxSize::TX_32X32 {
142    TxSet::TX_SET_DCTONLY
143  } else if use_reduced_set || tx_size_sqr == TxSize::TX_16X16 {
144    TxSet::TX_SET_INTRA_2
145  } else {
146    TxSet::TX_SET_INTRA_1
147  }
148}
149
150pub fn get_tx_set_index(
151  tx_size: TxSize, is_inter: bool, use_reduced_set: bool,
152) -> i8 {
153  let set_type = get_tx_set(tx_size, is_inter, use_reduced_set);
154
155  if is_inter {
156    tx_set_index_inter[set_type as usize]
157  } else {
158    tx_set_index_intra[set_type as usize]
159  }
160}
161
/// Transform type associated with each luma intra prediction mode, indexed
/// by the mode's discriminant. Read by `uv_intra_mode_to_tx_type_context`.
static intra_mode_to_tx_type_context: [TxType; INTRA_MODES] = [
  DCT_DCT,   // DC
  ADST_DCT,  // V
  DCT_ADST,  // H
  DCT_DCT,   // D45
  ADST_ADST, // D135
  ADST_DCT,  // D113
  DCT_ADST,  // D157
  DCT_ADST,  // D203
  ADST_DCT,  // D67
  ADST_ADST, // SMOOTH
  ADST_DCT,  // SMOOTH_V
  DCT_ADST,  // SMOOTH_H
  ADST_ADST, // PAETH
];
177
/// Maps each chroma intra prediction mode to the luma mode used for table
/// lookups. Chroma-from-luma (`CFL_PRED`) is treated as `DC_PRED`.
static uv2y: [PredictionMode; UV_INTRA_MODES] = [
  DC_PRED,       // UV_DC_PRED
  V_PRED,        // UV_V_PRED
  H_PRED,        // UV_H_PRED
  D45_PRED,      // UV_D45_PRED
  D135_PRED,     // UV_D135_PRED
  D113_PRED,     // UV_D113_PRED
  D157_PRED,     // UV_D157_PRED
  D203_PRED,     // UV_D203_PRED
  D67_PRED,      // UV_D67_PRED
  SMOOTH_PRED,   // UV_SMOOTH_PRED
  SMOOTH_V_PRED, // UV_SMOOTH_V_PRED
  SMOOTH_H_PRED, // UV_SMOOTH_H_PRED
  PAETH_PRED,    // UV_PAETH_PRED
  DC_PRED,       // CFL_PRED
];
194
195pub fn uv_intra_mode_to_tx_type_context(pred: PredictionMode) -> TxType {
196  intra_mode_to_tx_type_context[uv2y[pred as usize] as usize]
197}
198
// Level Map
// NOTE(review): most of the `*_CONTEXTS` values below size CDF tables that
// are defined outside this file; their meaning is inferred from the names.
pub const TXB_SKIP_CONTEXTS: usize = 13;

pub const EOB_COEF_CONTEXTS: usize = 9;

// The 1D (horizontal / vertical class) significance contexts are numbered
// right after the 2D ones: `NZ_MAP_CTX_0` starts at `SIG_COEF_CONTEXTS_2D`.
const SIG_COEF_CONTEXTS_2D: usize = 26;
const SIG_COEF_CONTEXTS_1D: usize = 16;
pub const SIG_COEF_CONTEXTS_EOB: usize = 4;
/// Total number of significance contexts (26 + 16 = 42).
pub const SIG_COEF_CONTEXTS: usize =
  SIG_COEF_CONTEXTS_2D + SIG_COEF_CONTEXTS_1D;

const COEFF_BASE_CONTEXTS: usize = SIG_COEF_CONTEXTS;
pub const DC_SIGN_CONTEXTS: usize = 3;

const BR_TMP_OFFSET: usize = 12;
const BR_REF_CAT: usize = 4;
pub const LEVEL_CONTEXTS: usize = 21;

pub const NUM_BASE_LEVELS: usize = 2;

pub const BR_CDF_SIZE: usize = 4;
/// Evaluates to 4 * (4 - 1) = 12.
pub const COEFF_BASE_RANGE: usize = 4 * (BR_CDF_SIZE - 1);

pub const COEFF_CONTEXT_BITS: usize = 6;
/// Low `COEFF_CONTEXT_BITS` bits set (0x3f).
pub const COEFF_CONTEXT_MASK: usize = (1 << COEFF_CONTEXT_BITS) - 1;
/// Evaluates to 12 + 2 + 1 = 15.
const MAX_BASE_BR_RANGE: usize = COEFF_BASE_RANGE + NUM_BASE_LEVELS + 1;

const BASE_CONTEXT_POSITION_NUM: usize = 12;

// Pad 4 extra columns to remove horizontal availability check.
// `TX_PAD_HOR` must stay equal to `1 << TX_PAD_HOR_LOG2`: `get_nz_mag`
// computes neighbour offsets with both forms interchangeably.
pub const TX_PAD_HOR_LOG2: usize = 2;
pub const TX_PAD_HOR: usize = 4;
// Pad 6 extra rows (2 on top and 4 on bottom) to remove vertical availability
// check.
pub const TX_PAD_TOP: usize = 2;
pub const TX_PAD_BOTTOM: usize = 4;
pub const TX_PAD_VER: usize = TX_PAD_TOP + TX_PAD_BOTTOM;
// Pad 16 extra bytes to avoid reading overflow in SIMD optimization.
const TX_PAD_END: usize = 16;
/// Size of a fully padded level buffer for the largest coded transform:
/// (32 + 4) * (32 + 6) + 16 = 1384 entries.
pub const TX_PAD_2D: usize = (MAX_CODED_TX_SIZE + TX_PAD_HOR)
  * (MAX_CODED_TX_SIZE + TX_PAD_VER)
  + TX_PAD_END;

/// Number of `TxClass` variants.
const TX_CLASSES: usize = 3;
243
/// Class of a transform type, as assigned by `tx_type_to_class`. The class
/// selects which neighbouring levels the context functions in this file
/// (`get_nz_mag`, `get_br_ctx`, ...) inspect.
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum TxClass {
  /// Every 2-D transform type, including `IDTX`.
  TX_CLASS_2D = 0,
  /// The `H_*` transform types.
  TX_CLASS_HORIZ = 1,
  /// The `V_*` transform types.
  TX_CLASS_VERT = 2,
}
250
/// Segmentation feature identifiers. The discriminants below `SEG_LVL_MAX`
/// index `seg_feature_bits` and `seg_feature_is_signed`; `SEG_LVL_MAX` itself
/// is the feature count, not a feature.
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum SegLvl {
  SEG_LVL_ALT_Q = 0,      /* Use alternate Quantizer .... */
  SEG_LVL_ALT_LF_Y_V = 1, /* Use alternate loop filter value on y plane vertical */
  SEG_LVL_ALT_LF_Y_H = 2, /* Use alternate loop filter value on y plane horizontal */
  SEG_LVL_ALT_LF_U = 3,   /* Use alternate loop filter value on u plane */
  SEG_LVL_ALT_LF_V = 4,   /* Use alternate loop filter value on v plane */
  SEG_LVL_REF_FRAME = 5,  /* Optional Segment reference frame */
  SEG_LVL_SKIP = 6,       /* Optional Segment (0,0) + skip mode */
  SEG_LVL_GLOBALMV = 7,
  SEG_LVL_MAX = 8,
}
263
// Per-feature tables, indexed by `SegLvl` discriminant.
// NOTE(review): presumably the bit width used to code each feature's value;
// confirm against the segmentation writer.
pub const seg_feature_bits: [u32; SegLvl::SEG_LVL_MAX as usize] =
  [8, 6, 6, 6, 6, 3, 0, 0];

// NOTE(review): presumably whether each feature's value carries a sign.
pub const seg_feature_is_signed: [bool; SegLvl::SEG_LVL_MAX as usize] =
  [true, true, true, true, true, false, false, false];
269
270use crate::context::TxClass::*;
271
/// Maps each transform type (by discriminant) to its `TxClass`: all 2-D
/// types and `IDTX` are `TX_CLASS_2D`, `V_*` types are `TX_CLASS_VERT` and
/// `H_*` types are `TX_CLASS_HORIZ`.
pub static tx_type_to_class: [TxClass; TX_TYPES] = [
  TX_CLASS_2D,    // DCT_DCT
  TX_CLASS_2D,    // ADST_DCT
  TX_CLASS_2D,    // DCT_ADST
  TX_CLASS_2D,    // ADST_ADST
  TX_CLASS_2D,    // FLIPADST_DCT
  TX_CLASS_2D,    // DCT_FLIPADST
  TX_CLASS_2D,    // FLIPADST_FLIPADST
  TX_CLASS_2D,    // ADST_FLIPADST
  TX_CLASS_2D,    // FLIPADST_ADST
  TX_CLASS_2D,    // IDTX
  TX_CLASS_VERT,  // V_DCT
  TX_CLASS_HORIZ, // H_DCT
  TX_CLASS_VERT,  // V_ADST
  TX_CLASS_HORIZ, // H_ADST
  TX_CLASS_VERT,  // V_FLIPADST
  TX_CLASS_HORIZ, // H_FLIPADST
];
290
/// End-of-block group for `eob` values 0..=32, indexed directly by `eob`.
/// Used by `get_eob_pos_token` when `eob < 33`.
pub static eob_to_pos_small: [u8; 33] = [
  0, 1, 2, // 0-2
  3, 3, // 3-4
  4, 4, 4, 4, // 5-8
  5, 5, 5, 5, 5, 5, 5, 5, // 9-16
  6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, // 17-32
];
298
/// End-of-block group for `eob >= 33`. `get_eob_pos_token` indexes this table
/// with `min((eob - 1) >> 5, 16)`, so index 0 is never reached from there.
pub static eob_to_pos_large: [u8; 17] = [
  6, // place holder
  7, // 33-64
  8, 8, // 65-128
  9, 9, 9, 9, // 129-256
  10, 10, 10, 10, 10, 10, 10, 10, // 257-512
  11, // 513-
];
307
/// First `eob` value belonging to each end-of-block group.
pub static k_eob_group_start: [u16; 12] =
  [0, 1, 2, 3, 5, 9, 17, 33, 65, 129, 257, 513];
/// For groups 0 through 10, group `g` spans `1 << k_eob_offset_bits[g]`
/// consecutive `eob` values (the gap to the next `k_eob_group_start` entry).
// NOTE(review): presumably the number of extra bits coded for the offset
// within the group; confirm against the coefficient writer.
pub static k_eob_offset_bits: [u16; 12] = [0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
311
312// The ctx offset table when TX is TX_CLASS_2D.
313// TX col and row indices are clamped to 4
314
315#[rustfmt::skip]
316pub static av1_nz_map_ctx_offset: [[[i8; 5]; 5]; TxSize::TX_SIZES_ALL] = [
317  // TX_4X4
318  [
319    [ 0,  1,  6,  6, 0],
320    [ 1,  6,  6, 21, 0],
321    [ 6,  6, 21, 21, 0],
322    [ 6, 21, 21, 21, 0],
323    [ 0,  0,  0,  0, 0]
324  ],
325  // TX_8X8
326  [
327    [ 0,  1,  6,  6, 21],
328    [ 1,  6,  6, 21, 21],
329    [ 6,  6, 21, 21, 21],
330    [ 6, 21, 21, 21, 21],
331    [21, 21, 21, 21, 21]
332  ],
333  // TX_16X16
334  [
335    [ 0,  1,  6,  6, 21],
336    [ 1,  6,  6, 21, 21],
337    [ 6,  6, 21, 21, 21],
338    [ 6, 21, 21, 21, 21],
339    [21, 21, 21, 21, 21]
340  ],
341  // TX_32X32
342  [
343    [ 0,  1,  6,  6, 21],
344    [ 1,  6,  6, 21, 21],
345    [ 6,  6, 21, 21, 21],
346    [ 6, 21, 21, 21, 21],
347    [21, 21, 21, 21, 21]
348  ],
349  // TX_64X64
350  [
351    [ 0,  1,  6,  6, 21],
352    [ 1,  6,  6, 21, 21],
353    [ 6,  6, 21, 21, 21],
354    [ 6, 21, 21, 21, 21],
355    [21, 21, 21, 21, 21]
356  ],
357  // TX_4X8
358  [
359    [ 0, 11, 11, 11, 0],
360    [11, 11, 11, 11, 0],
361    [ 6,  6, 21, 21, 0],
362    [ 6, 21, 21, 21, 0],
363    [21, 21, 21, 21, 0]
364  ],
365  // TX_8X4
366  [
367    [ 0, 16,  6,  6, 21],
368    [16, 16,  6, 21, 21],
369    [16, 16, 21, 21, 21],
370    [16, 16, 21, 21, 21],
371    [ 0,  0,  0,  0, 0]
372  ],
373  // TX_8X16
374  [
375    [ 0, 11, 11, 11, 11],
376    [11, 11, 11, 11, 11],
377    [ 6,  6, 21, 21, 21],
378    [ 6, 21, 21, 21, 21],
379    [21, 21, 21, 21, 21]
380  ],
381  // TX_16X8
382  [
383    [ 0, 16,  6,  6, 21],
384    [16, 16,  6, 21, 21],
385    [16, 16, 21, 21, 21],
386    [16, 16, 21, 21, 21],
387    [16, 16, 21, 21, 21]
388  ],
389  // TX_16X32
390  [
391    [ 0, 11, 11, 11, 11],
392    [11, 11, 11, 11, 11],
393    [ 6,  6, 21, 21, 21],
394    [ 6, 21, 21, 21, 21],
395    [21, 21, 21, 21, 21]
396  ],
397  // TX_32X16
398  [
399    [ 0, 16,  6,  6, 21],
400    [16, 16,  6, 21, 21],
401    [16, 16, 21, 21, 21],
402    [16, 16, 21, 21, 21],
403    [16, 16, 21, 21, 21]
404  ],
405  // TX_32X64
406  [
407    [ 0, 11, 11, 11, 11],
408    [11, 11, 11, 11, 11],
409    [ 6,  6, 21, 21, 21],
410    [ 6, 21, 21, 21, 21],
411    [21, 21, 21, 21, 21]
412  ],
413  // TX_64X32
414  [
415    [ 0, 16,  6,  6, 21],
416    [16, 16,  6, 21, 21],
417    [16, 16, 21, 21, 21],
418    [16, 16, 21, 21, 21],
419    [16, 16, 21, 21, 21]
420  ],
421  // TX_4X16
422  [
423    [ 0, 11, 11, 11, 0],
424    [11, 11, 11, 11, 0],
425    [ 6,  6, 21, 21, 0],
426    [ 6, 21, 21, 21, 0],
427    [21, 21, 21, 21, 0]
428  ],
429  // TX_16X4
430  [
431    [ 0, 16,  6,  6, 21],
432    [16, 16,  6, 21, 21],
433    [16, 16, 21, 21, 21],
434    [16, 16, 21, 21, 21],
435    [ 0,  0,  0,  0, 0]
436  ],
437  // TX_8X32
438  [
439    [ 0, 11, 11, 11, 11],
440    [11, 11, 11, 11, 11],
441    [ 6,  6, 21, 21, 21],
442    [ 6, 21, 21, 21, 21],
443    [21, 21, 21, 21, 21]
444  ],
445  // TX_32X8
446  [
447    [ 0, 16,  6,  6, 21],
448    [16, 16,  6, 21, 21],
449    [16, 16, 21, 21, 21],
450    [16, 16, 21, 21, 21],
451    [16, 16, 21, 21, 21]
452  ],
453  // TX_16X64
454  [
455    [ 0, 11, 11, 11, 11],
456    [11, 11, 11, 11, 11],
457    [ 6,  6, 21, 21, 21],
458    [ 6, 21, 21, 21, 21],
459    [21, 21, 21, 21, 21]
460  ],
461  // TX_64X16
462  [
463    [ 0, 16,  6,  6, 21],
464    [16, 16,  6, 21, 21],
465    [16, 16, 21, 21, 21],
466    [16, 16, 21, 21, 21],
467    [16, 16, 21, 21, 21]
468  ]
469];
470
471const NZ_MAP_CTX_0: usize = SIG_COEF_CONTEXTS_2D;
472const NZ_MAP_CTX_5: usize = NZ_MAP_CTX_0 + 5;
473const NZ_MAP_CTX_10: usize = NZ_MAP_CTX_0 + 10;
474
475pub static nz_map_ctx_offset_1d: [usize; 32] = [
476  NZ_MAP_CTX_0,
477  NZ_MAP_CTX_5,
478  NZ_MAP_CTX_10,
479  NZ_MAP_CTX_10,
480  NZ_MAP_CTX_10,
481  NZ_MAP_CTX_10,
482  NZ_MAP_CTX_10,
483  NZ_MAP_CTX_10,
484  NZ_MAP_CTX_10,
485  NZ_MAP_CTX_10,
486  NZ_MAP_CTX_10,
487  NZ_MAP_CTX_10,
488  NZ_MAP_CTX_10,
489  NZ_MAP_CTX_10,
490  NZ_MAP_CTX_10,
491  NZ_MAP_CTX_10,
492  NZ_MAP_CTX_10,
493  NZ_MAP_CTX_10,
494  NZ_MAP_CTX_10,
495  NZ_MAP_CTX_10,
496  NZ_MAP_CTX_10,
497  NZ_MAP_CTX_10,
498  NZ_MAP_CTX_10,
499  NZ_MAP_CTX_10,
500  NZ_MAP_CTX_10,
501  NZ_MAP_CTX_10,
502  NZ_MAP_CTX_10,
503  NZ_MAP_CTX_10,
504  NZ_MAP_CTX_10,
505  NZ_MAP_CTX_10,
506  NZ_MAP_CTX_10,
507  NZ_MAP_CTX_10,
508];
509
/// Number of offset pairs stored per class in `mag_ref_offset_with_txclass`.
const CONTEXT_MAG_POSITION_NUM: usize = 3;

// NOTE(review): not referenced by any code visible in this file. Presumably
// one row per `TxClass` (2D, HORIZ, VERT), each listing three neighbour
// offsets as coordinate pairs; confirm the axis order before relying on it.
static mag_ref_offset_with_txclass: [[[usize; 2]; CONTEXT_MAG_POSITION_NUM];
  3] = [
  [[0, 1], [1, 0], [1, 1]],
  [[0, 1], [1, 0], [0, 2]],
  [[0, 1], [1, 0], [2, 0]],
];
518
519// End of Level Map
520
/// Pair of entropy-coding contexts for one transform block.
// NOTE(review): the code that fills and consumes this struct is outside this
// file; field meanings are taken from their names.
pub struct TXB_CTX {
  /// Context for the transform-block skip flag.
  pub txb_skip_ctx: usize,
  /// Context for the DC coefficient sign.
  pub dc_sign_ctx: usize,
}
525
526impl ContextWriter<'_> {
527  /// # Panics
528  ///
529  /// - If an invalid combination of `tx_type` and `tx_size` is passed
530  pub fn write_tx_type<W: Writer>(
531    &mut self, w: &mut W, tx_size: TxSize, tx_type: TxType,
532    y_mode: PredictionMode, is_inter: bool, use_reduced_tx_set: bool,
533  ) {
534    let square_tx_size = tx_size.sqr();
535    let tx_set = get_tx_set(tx_size, is_inter, use_reduced_tx_set);
536    let num_tx_types = num_tx_set[tx_set as usize];
537
538    if num_tx_types > 1 {
539      let tx_set_index =
540        get_tx_set_index(tx_size, is_inter, use_reduced_tx_set);
541      assert!(tx_set_index > 0);
542      assert!(av1_tx_used[tx_set as usize][tx_type as usize] != 0);
543
544      if is_inter {
545        let s = av1_tx_ind[tx_set as usize][tx_type as usize] as u32;
546        if tx_set_index == 1 {
547          let cdf = &self.fc.inter_tx_1_cdf[square_tx_size as usize];
548          symbol_with_update!(self, w, s, cdf);
549        } else if tx_set_index == 2 {
550          let cdf = &self.fc.inter_tx_2_cdf[square_tx_size as usize];
551          symbol_with_update!(self, w, s, cdf);
552        } else {
553          let cdf = &self.fc.inter_tx_3_cdf[square_tx_size as usize];
554          symbol_with_update!(self, w, s, cdf);
555        }
556      } else {
557        let intra_dir = y_mode;
558        // TODO: Once use_filter_intra is enabled,
559        // intra_dir =
560        // fimode_to_intradir[mbmi->filter_intra_mode_info.filter_intra_mode];
561
562        let s = av1_tx_ind[tx_set as usize][tx_type as usize] as u32;
563        if tx_set_index == 1 {
564          let cdf = &self.fc.intra_tx_1_cdf[square_tx_size as usize]
565            [intra_dir as usize];
566          symbol_with_update!(self, w, s, cdf);
567        } else {
568          let cdf = &self.fc.intra_tx_2_cdf[square_tx_size as usize]
569            [intra_dir as usize];
570          symbol_with_update!(self, w, s, cdf);
571        }
572      }
573    }
574  }
575
576  fn get_tx_size_context(
577    &self, bo: TileBlockOffset, bsize: BlockSize,
578  ) -> usize {
579    let max_tx_size = max_txsize_rect_lookup[bsize as usize];
580    let max_tx_wide = max_tx_size.width() as u8;
581    let max_tx_high = max_tx_size.height() as u8;
582    let has_above = bo.0.y > 0;
583    let has_left = bo.0.x > 0;
584    let mut above = self.bc.above_tx_context[bo.0.x] >= max_tx_wide;
585    let mut left = self.bc.left_tx_context[bo.y_in_sb()] >= max_tx_high;
586
587    if has_above {
588      let above_blk = self.bc.blocks.above_of(bo);
589      if above_blk.is_inter() {
590        above = (above_blk.n4_w << MI_SIZE_LOG2) >= max_tx_wide;
591      };
592    }
593    if has_left {
594      let left_blk = self.bc.blocks.left_of(bo);
595      if left_blk.is_inter() {
596        left = (left_blk.n4_h << MI_SIZE_LOG2) >= max_tx_high;
597      };
598    }
599    if has_above && has_left {
600      return above as usize + left as usize;
601    };
602    if has_above {
603      return above as usize;
604    };
605    if has_left {
606      return left as usize;
607    };
608    0
609  }
610
611  pub fn write_tx_size_intra<W: Writer>(
612    &mut self, w: &mut W, bo: TileBlockOffset, bsize: BlockSize,
613    tx_size: TxSize,
614  ) {
615    fn tx_size_to_depth(tx_size: TxSize, bsize: BlockSize) -> usize {
616      let mut ctx_size = max_txsize_rect_lookup[bsize as usize];
617      let mut depth: usize = 0;
618      while tx_size != ctx_size {
619        depth += 1;
620        ctx_size = sub_tx_size_map[ctx_size as usize];
621        debug_assert!(depth <= MAX_TX_DEPTH);
622      }
623      depth
624    }
625    fn bsize_to_max_depth(bsize: BlockSize) -> usize {
626      let mut tx_size: TxSize = max_txsize_rect_lookup[bsize as usize];
627      let mut depth = 0;
628      while depth < MAX_TX_DEPTH && tx_size != TX_4X4 {
629        depth += 1;
630        tx_size = sub_tx_size_map[tx_size as usize];
631        debug_assert!(depth <= MAX_TX_DEPTH);
632      }
633      depth
634    }
635    fn bsize_to_tx_size_cat(bsize: BlockSize) -> usize {
636      let mut tx_size: TxSize = max_txsize_rect_lookup[bsize as usize];
637      debug_assert!(tx_size != TX_4X4);
638      let mut depth = 0;
639      while tx_size != TX_4X4 {
640        depth += 1;
641        tx_size = sub_tx_size_map[tx_size as usize];
642      }
643      debug_assert!(depth <= MAX_TX_CATS);
644
645      depth - 1
646    }
647
648    debug_assert!(!self.bc.blocks[bo].is_inter());
649    debug_assert!(bsize > BlockSize::BLOCK_4X4);
650
651    let tx_size_ctx = self.get_tx_size_context(bo, bsize);
652    let depth = tx_size_to_depth(tx_size, bsize);
653
654    let max_depths = bsize_to_max_depth(bsize);
655    let tx_size_cat = bsize_to_tx_size_cat(bsize);
656
657    debug_assert!(depth <= max_depths);
658    debug_assert!(!tx_size.is_rect() || bsize.is_rect_tx_allowed());
659
660    if tx_size_cat > 0 {
661      let cdf = &self.fc.tx_size_cdf[tx_size_cat - 1][tx_size_ctx];
662      symbol_with_update!(self, w, depth as u32, cdf);
663    } else {
664      let cdf = &self.fc.tx_size_8x8_cdf[tx_size_ctx];
665      symbol_with_update!(self, w, depth as u32, cdf);
666    }
667  }
668
669  // Based on https://aomediacodec.github.io/av1-spec/#cdf-selection-process
670  // Used to decide the cdf (context) for txfm_split
671  fn get_above_tx_width(
672    &self, bo: TileBlockOffset, _bsize: BlockSize, _tx_size: TxSize,
673    first_tx: bool,
674  ) -> usize {
675    let has_above = bo.0.y > 0;
676    if first_tx {
677      if !has_above {
678        return 64;
679      }
680      let above_blk = self.bc.blocks.above_of(bo);
681      if above_blk.skip && above_blk.is_inter() {
682        return above_blk.bsize.width();
683      }
684    }
685    self.bc.above_tx_context[bo.0.x] as usize
686  }
687
688  fn get_left_tx_height(
689    &self, bo: TileBlockOffset, _bsize: BlockSize, _tx_size: TxSize,
690    first_tx: bool,
691  ) -> usize {
692    let has_left = bo.0.x > 0;
693    if first_tx {
694      if !has_left {
695        return 64;
696      }
697      let left_blk = self.bc.blocks.left_of(bo);
698      if left_blk.skip && left_blk.is_inter() {
699        return left_blk.bsize.height();
700      }
701    }
702    self.bc.left_tx_context[bo.y_in_sb()] as usize
703  }
704
705  fn txfm_partition_context(
706    &self, bo: TileBlockOffset, bsize: BlockSize, tx_size: TxSize, tbx: usize,
707    tby: usize,
708  ) -> usize {
709    debug_assert!(tx_size > TX_4X4);
710    debug_assert!(bsize > BlockSize::BLOCK_4X4);
711
712    // TODO: from 2nd level partition, must know whether the tx block is the topmost(or leftmost) within a partition
713    let above = (self.get_above_tx_width(bo, bsize, tx_size, tby == 0)
714      < tx_size.width()) as usize;
715    let left = (self.get_left_tx_height(bo, bsize, tx_size, tbx == 0)
716      < tx_size.height()) as usize;
717
718    let max_tx_size: TxSize = bsize.tx_size().sqr_up();
719    let category: usize = (tx_size.sqr_up() != max_tx_size) as usize
720      + (TxSize::TX_SIZES - 1 - max_tx_size as usize) * 2;
721
722    debug_assert!(category < TXFM_PARTITION_CONTEXTS);
723
724    category * 3 + above + left
725  }
726
727  pub fn write_tx_size_inter<W: Writer>(
728    &mut self, w: &mut W, bo: TileBlockOffset, bsize: BlockSize,
729    tx_size: TxSize, txfm_split: bool, tbx: usize, tby: usize, depth: usize,
730  ) {
731    if bo.0.x >= self.bc.blocks.cols() || bo.0.y >= self.bc.blocks.rows() {
732      return;
733    }
734    debug_assert!(self.bc.blocks[bo].is_inter());
735    debug_assert!(bsize > BlockSize::BLOCK_4X4);
736    debug_assert!(!tx_size.is_rect() || bsize.is_rect_tx_allowed());
737
738    if tx_size != TX_4X4 && depth < MAX_VARTX_DEPTH {
739      let ctx = self.txfm_partition_context(bo, bsize, tx_size, tbx, tby);
740      let cdf = &self.fc.txfm_partition_cdf[ctx];
741      symbol_with_update!(self, w, txfm_split as u32, cdf);
742    } else {
743      debug_assert!(!txfm_split);
744    }
745
746    if !txfm_split {
747      self.bc.update_tx_size_context(bo, tx_size.block_size(), tx_size, false);
748    } else {
749      // if txfm_split == true, split one level only
750      let split_tx_size = sub_tx_size_map[tx_size as usize];
751      let bw = bsize.width_mi() / split_tx_size.width_mi();
752      let bh = bsize.height_mi() / split_tx_size.height_mi();
753
754      for by in 0..bh {
755        for bx in 0..bw {
756          let tx_bo = TileBlockOffset(BlockOffset {
757            x: bo.0.x + bx * split_tx_size.width_mi(),
758            y: bo.0.y + by * split_tx_size.height_mi(),
759          });
760          self.write_tx_size_inter(
761            w,
762            tx_bo,
763            bsize,
764            split_tx_size,
765            false,
766            bx,
767            by,
768            depth + 1,
769          );
770        }
771      }
772    }
773  }
774
775  #[inline]
776  pub const fn get_txsize_entropy_ctx(tx_size: TxSize) -> usize {
777    (tx_size.sqr() as usize + tx_size.sqr_up() as usize + 1) >> 1
778  }
779
780  pub fn txb_init_levels<T: Coefficient>(
781    &self, coeffs: &[T], height: usize, levels: &mut [u8],
782    levels_stride: usize,
783  ) {
784    // Coefficients and levels are transposed from how they work in the spec
785    for (coeffs_col, levels_col) in
786      coeffs.chunks_exact(height).zip(levels.chunks_exact_mut(levels_stride))
787    {
788      for (coeff, level) in coeffs_col.iter().zip(levels_col) {
789        *level = coeff.abs().min(T::cast_from(127)).as_();
790      }
791    }
792  }
793
794  // Since the coefficients and levels are transposed in relation to how they
795  // work in the spec, use the log of block height in our calculations instead
796  // of block width.
797  #[inline]
798  pub const fn get_txb_bhl(tx_size: TxSize) -> usize {
799    av1_get_coded_tx_size(tx_size).height_log2()
800  }
801
802  /// Returns `(eob_pt, eob_extra)`
803  ///
804  /// # Panics
805  ///
806  /// - If `eob` is prior to the start of the group
807  #[inline]
808  pub fn get_eob_pos_token(eob: u16) -> (u32, u32) {
809    let t = if eob < 33 {
810      eob_to_pos_small[usize::from(eob)] as u32
811    } else {
812      let e = usize::from(cmp::min((eob - 1) >> 5, 16));
813      eob_to_pos_large[e] as u32
814    };
815    assert!(eob as i32 >= k_eob_group_start[t as usize] as i32);
816    let extra = eob as u32 - k_eob_group_start[t as usize] as u32;
817
818    (t, extra)
819  }
820
821  pub fn get_nz_mag(levels: &[u8], bhl: usize, tx_class: TxClass) -> usize {
822    // Levels are transposed from how they work in the spec
823
824    // May version.
825    // Note: AOMMIN(level, 3) is useless for decoder since level < 3.
826    let mut mag = cmp::min(3, levels[1]); // { 1, 0 }
827    mag += cmp::min(3, levels[(1 << bhl) + TX_PAD_HOR]); // { 0, 1 }
828
829    if tx_class == TX_CLASS_2D {
830      mag += cmp::min(3, levels[(1 << bhl) + TX_PAD_HOR + 1]); // { 1, 1 }
831      mag += cmp::min(3, levels[2]); // { 2, 0 }
832      mag += cmp::min(3, levels[(2 << bhl) + (2 << TX_PAD_HOR_LOG2)]); // { 0, 2 }
833    } else if tx_class == TX_CLASS_VERT {
834      mag += cmp::min(3, levels[2]); // { 2, 0 }
835      mag += cmp::min(3, levels[3]); // { 3, 0 }
836      mag += cmp::min(3, levels[4]); // { 4, 0 }
837    } else {
838      mag += cmp::min(3, levels[(2 << bhl) + (2 << TX_PAD_HOR_LOG2)]); // { 0, 2 }
839      mag += cmp::min(3, levels[(3 << bhl) + (3 << TX_PAD_HOR_LOG2)]); // { 0, 3 }
840      mag += cmp::min(3, levels[(4 << bhl) + (4 << TX_PAD_HOR_LOG2)]); // { 0, 4 }
841    }
842
843    mag as usize
844  }
845
846  fn get_nz_map_ctx_from_stats(
847    stats: usize,
848    coeff_idx: usize, // raster order
849    bhl: usize,
850    tx_size: TxSize,
851    tx_class: TxClass,
852  ) -> usize {
853    if (tx_class as u32 | coeff_idx as u32) == 0 {
854      return 0;
855    };
856
857    // Coefficients are transposed from how they work in the spec
858    let col: usize = coeff_idx >> bhl;
859    let row: usize = coeff_idx - (col << bhl);
860
861    let ctx = ((stats + 1) >> 1).min(4);
862
863    ctx
864      + match tx_class {
865        TX_CLASS_2D => {
866          // This is the algorithm to generate table av1_nz_map_ctx_offset[].
867          // const int width = tx_size_wide[tx_size];
868          // const int height = tx_size_high[tx_size];
869          // if (width < height) {
870          //   if (row < 2) return 11 + ctx;
871          // } else if (width > height) {
872          //   if (col < 2) return 16 + ctx;
873          // }
874          // if (row + col < 2) return ctx + 1;
875          // if (row + col < 4) return 5 + ctx + 1;
876          // return 21 + ctx;
877          av1_nz_map_ctx_offset[tx_size as usize][cmp::min(row, 4)]
878            [cmp::min(col, 4)] as usize
879        }
880        TX_CLASS_HORIZ => nz_map_ctx_offset_1d[col],
881        TX_CLASS_VERT => nz_map_ctx_offset_1d[row],
882      }
883  }
884
885  fn get_nz_map_ctx(
886    levels: &[u8], coeff_idx: usize, bhl: usize, area: usize, scan_idx: usize,
887    is_eob: bool, tx_size: TxSize, tx_class: TxClass,
888  ) -> usize {
889    if is_eob {
890      if scan_idx == 0 {
891        return 0;
892      }
893      if scan_idx <= area / 8 {
894        return 1;
895      }
896      if scan_idx <= area / 4 {
897        return 2;
898      }
899      return 3;
900    }
901
902    // Levels are transposed from how they work in the spec
903    let padded_idx = coeff_idx + ((coeff_idx >> bhl) << TX_PAD_HOR_LOG2);
904    let stats = Self::get_nz_mag(&levels[padded_idx..], bhl, tx_class);
905
906    Self::get_nz_map_ctx_from_stats(stats, coeff_idx, bhl, tx_size, tx_class)
907  }
908
909  /// `coeff_contexts_no_scan` is not in the scan order.
910  /// Value for `pos = scan[i]` is at `coeff[i]`, not at `coeff[pos]`.
911  pub fn get_nz_map_contexts<'c>(
912    &self, levels: &mut [u8], scan: &[u16], eob: u16, tx_size: TxSize,
913    tx_class: TxClass, coeff_contexts_no_scan: &'c mut [MaybeUninit<i8>],
914  ) -> &'c mut [i8] {
915    let bhl = Self::get_txb_bhl(tx_size);
916    let area = av1_get_coded_tx_size(tx_size).area();
917
918    let scan = &scan[..usize::from(eob)];
919    let coeffs = &mut coeff_contexts_no_scan[..usize::from(eob)];
920    for (i, (coeff, pos)) in
921      coeffs.iter_mut().zip(scan.iter().copied()).enumerate()
922    {
923      coeff.write(Self::get_nz_map_ctx(
924        levels,
925        pos as usize,
926        bhl,
927        area,
928        i,
929        i == usize::from(eob) - 1,
930        tx_size,
931        tx_class,
932      ) as i8);
933    }
934    // SAFETY: every element has been initialized
935    unsafe { slice_assume_init_mut(coeffs) }
936  }
937
938  pub fn get_br_ctx(
939    levels: &[u8],
940    coeff_idx: usize, // raster order
941    bhl: usize,
942    tx_class: TxClass,
943  ) -> usize {
944    // Coefficients and levels are transposed from how they work in the spec
945    let col: usize = coeff_idx >> bhl;
946    let row: usize = coeff_idx - (col << bhl);
947    let stride: usize = (1 << bhl) + TX_PAD_HOR;
948    let pos: usize = col * stride + row;
949    let mut mag: usize = (levels[pos + 1] + levels[pos + stride]) as usize;
950
951    match tx_class {
952      TX_CLASS_2D => {
953        mag += levels[pos + stride + 1] as usize;
954        mag = cmp::min((mag + 1) >> 1, 6);
955        if coeff_idx == 0 {
956          return mag;
957        }
958        if (row < 2) && (col < 2) {
959          return mag + 7;
960        }
961      }
962      TX_CLASS_HORIZ => {
963        mag += levels[pos + (stride << 1)] as usize;
964        mag = cmp::min((mag + 1) >> 1, 6);
965        if coeff_idx == 0 {
966          return mag;
967        }
968        if col == 0 {
969          return mag + 7;
970        }
971      }
972      TX_CLASS_VERT => {
973        mag += levels[pos + 2] as usize;
974        mag = cmp::min((mag + 1) >> 1, 6);
975        if coeff_idx == 0 {
976          return mag;
977        }
978        if row == 0 {
979          return mag + 7;
980        }
981      }
982    }
983
984    mag + 14
985  }
986}