rav1e/transform/
mod.rs

1// Copyright (c) 2017-2022, The rav1e contributors. All rights reserved
2//
3// This source code is subject to the terms of the BSD 2 Clause License and
4// the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
5// was not distributed with this source code in the LICENSE file, you can
6// obtain it at www.aomedia.org/license/software. If the Alliance for Open
7// Media Patent License 1.0 was not distributed with this source code in the
8// PATENTS file, you can obtain it at www.aomedia.org/license/patent.
9
10#![allow(non_camel_case_types)]
11#![allow(dead_code)]
12
13#[macro_use]
14pub mod forward_shared;
15
16pub use self::forward::forward_transform;
17pub use self::inverse::inverse_transform_add;
18
19use crate::context::MI_SIZE_LOG2;
20use crate::partition::{BlockSize, BlockSize::*};
21use crate::util::*;
22
23use TxSize::*;
24
25pub mod forward;
26pub mod inverse;
27
28pub static RAV1E_TX_TYPES: &[TxType] = &[
29  TxType::DCT_DCT,
30  TxType::ADST_DCT,
31  TxType::DCT_ADST,
32  TxType::ADST_ADST,
33  // TODO: Add a speed setting for FLIPADST
34  // TxType::FLIPADST_DCT,
35  // TxType::DCT_FLIPADST,
36  // TxType::FLIPADST_FLIPADST,
37  // TxType::ADST_FLIPADST,
38  // TxType::FLIPADST_ADST,
39  TxType::IDTX,
40  TxType::V_DCT,
41  TxType::H_DCT,
42  //TxType::V_FLIPADST,
43  //TxType::H_FLIPADST,
44];
45
/// Fixed-point constants shared by the transform implementations.
///
/// These are compile-time `const`s (rather than `static`s) so they can be used
/// in constant expressions and patterns; reading them by value is unchanged.
pub mod consts {
  /// Number of fractional bits used by [`SQRT2`] and [`INV_SQRT2`].
  pub const SQRT2_BITS: usize = 12;
  /// `sqrt(2)` in fixed point with [`SQRT2_BITS`] fractional bits:
  /// round(2^12 * sqrt(2)).
  pub const SQRT2: i32 = 5793;
  /// `1 / sqrt(2)` in fixed point with [`SQRT2_BITS`] fractional bits:
  /// round(2^12 / sqrt(2)).
  pub const INV_SQRT2: i32 = 2896;
}
51
/// Number of regular 2-D transform types (`TxType::DCT_DCT` through
/// `TxType::H_FLIPADST`).
pub const TX_TYPES: usize = 16;
/// Number of transform types when the Walsh-Hadamard type
/// (`TxType::WHT_WHT`) is counted as well.
pub const TX_TYPES_PLUS_LL: usize = TX_TYPES + 1;
54
/// Two-dimensional transform type.
///
/// For the `A_B` names, `A` is the 1-D transform applied vertically and `B`
/// the one applied horizontally. `V_*` types apply the named transform
/// vertically with identity horizontally; `H_*` types do the reverse.
/// The 1-D lookup tables in this module are laid out in discriminant order,
/// so the explicit values below must not change.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd)]
pub enum TxType {
  /// DCT in both directions.
  DCT_DCT = 0,
  /// ADST vertically, DCT horizontally.
  ADST_DCT = 1,
  /// DCT vertically, ADST horizontally.
  DCT_ADST = 2,
  /// ADST in both directions.
  ADST_ADST = 3,
  /// Flipped ADST vertically, DCT horizontally.
  FLIPADST_DCT = 4,
  /// DCT vertically, flipped ADST horizontally.
  DCT_FLIPADST = 5,
  /// Flipped ADST in both directions.
  FLIPADST_FLIPADST = 6,
  /// ADST vertically, flipped ADST horizontally.
  ADST_FLIPADST = 7,
  /// Flipped ADST vertically, ADST horizontally.
  FLIPADST_ADST = 8,
  /// Identity in both directions.
  IDTX = 9,
  /// DCT vertically, identity horizontally.
  V_DCT = 10,
  /// Identity vertically, DCT horizontally.
  H_DCT = 11,
  /// ADST vertically, identity horizontally.
  V_ADST = 12,
  /// Identity vertically, ADST horizontally.
  H_ADST = 13,
  /// Flipped ADST vertically, identity horizontally.
  V_FLIPADST = 14,
  /// Identity vertically, flipped ADST horizontally.
  H_FLIPADST = 15,
  /// Walsh-Hadamard transform in both directions.
  WHT_WHT = 16,
}
75
76impl TxType {
77  /// Compute transform type for inter chroma.
78  ///
79  /// <https://aomediacodec.github.io/av1-spec/#compute-transform-type-function>
80  #[inline]
81  pub fn uv_inter(self, uv_tx_size: TxSize) -> Self {
82    use TxType::*;
83    if uv_tx_size.sqr_up() == TX_32X32 {
84      match self {
85        IDTX => IDTX,
86        _ => DCT_DCT,
87      }
88    } else if uv_tx_size.sqr() == TX_16X16 {
89      match self {
90        V_ADST | H_ADST | V_FLIPADST | H_FLIPADST => DCT_DCT,
91        _ => self,
92      }
93    } else {
94      self
95    }
96  }
97}
98
/// Transform size, named `TX_<width>X<height>` in pixels.
///
/// Square sizes come first (smallest to largest), then the 1:2 / 2:1
/// rectangles, then the 1:4 / 4:1 rectangles. The derived ordering follows
/// this declaration order.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd)]
pub enum TxSize {
  // Square sizes.
  TX_4X4 = 0,
  TX_8X8 = 1,
  TX_16X16 = 2,
  TX_32X32 = 3,
  TX_64X64 = 4,

  // 1:2 and 2:1 rectangles.
  TX_4X8 = 5,
  TX_8X4 = 6,
  TX_8X16 = 7,
  TX_16X8 = 8,
  TX_16X32 = 9,
  TX_32X16 = 10,
  TX_32X64 = 11,
  TX_64X32 = 12,

  // 1:4 and 4:1 rectangles.
  TX_4X16 = 13,
  TX_16X4 = 14,
  TX_8X32 = 15,
  TX_32X8 = 16,
  TX_16X64 = 17,
  TX_64X16 = 18,
}
124
125impl TxSize {
126  /// Number of square transform sizes
127  pub const TX_SIZES: usize = 5;
128
129  /// Number of transform sizes (including non-square sizes)
130  pub const TX_SIZES_ALL: usize = 14 + 5;
131
132  #[inline]
133  pub const fn width(self) -> usize {
134    1 << self.width_log2()
135  }
136
137  #[inline]
138  pub const fn width_log2(self) -> usize {
139    match self {
140      TX_4X4 | TX_4X8 | TX_4X16 => 2,
141      TX_8X8 | TX_8X4 | TX_8X16 | TX_8X32 => 3,
142      TX_16X16 | TX_16X8 | TX_16X32 | TX_16X4 | TX_16X64 => 4,
143      TX_32X32 | TX_32X16 | TX_32X64 | TX_32X8 => 5,
144      TX_64X64 | TX_64X32 | TX_64X16 => 6,
145    }
146  }
147
148  #[inline]
149  pub const fn width_index(self) -> usize {
150    self.width_log2() - TX_4X4.width_log2()
151  }
152
153  #[inline]
154  pub const fn height(self) -> usize {
155    1 << self.height_log2()
156  }
157
158  #[inline]
159  pub const fn height_log2(self) -> usize {
160    match self {
161      TX_4X4 | TX_8X4 | TX_16X4 => 2,
162      TX_8X8 | TX_4X8 | TX_16X8 | TX_32X8 => 3,
163      TX_16X16 | TX_8X16 | TX_32X16 | TX_4X16 | TX_64X16 => 4,
164      TX_32X32 | TX_16X32 | TX_64X32 | TX_8X32 => 5,
165      TX_64X64 | TX_32X64 | TX_16X64 => 6,
166    }
167  }
168
169  #[inline]
170  pub const fn height_index(self) -> usize {
171    self.height_log2() - TX_4X4.height_log2()
172  }
173
174  #[inline]
175  pub const fn width_mi(self) -> usize {
176    self.width() >> MI_SIZE_LOG2
177  }
178
179  #[inline]
180  pub const fn area(self) -> usize {
181    1 << self.area_log2()
182  }
183
184  #[inline]
185  pub const fn area_log2(self) -> usize {
186    self.width_log2() + self.height_log2()
187  }
188
189  #[inline]
190  pub const fn height_mi(self) -> usize {
191    self.height() >> MI_SIZE_LOG2
192  }
193
194  #[inline]
195  pub const fn block_size(self) -> BlockSize {
196    match self {
197      TX_4X4 => BLOCK_4X4,
198      TX_8X8 => BLOCK_8X8,
199      TX_16X16 => BLOCK_16X16,
200      TX_32X32 => BLOCK_32X32,
201      TX_64X64 => BLOCK_64X64,
202      TX_4X8 => BLOCK_4X8,
203      TX_8X4 => BLOCK_8X4,
204      TX_8X16 => BLOCK_8X16,
205      TX_16X8 => BLOCK_16X8,
206      TX_16X32 => BLOCK_16X32,
207      TX_32X16 => BLOCK_32X16,
208      TX_32X64 => BLOCK_32X64,
209      TX_64X32 => BLOCK_64X32,
210      TX_4X16 => BLOCK_4X16,
211      TX_16X4 => BLOCK_16X4,
212      TX_8X32 => BLOCK_8X32,
213      TX_32X8 => BLOCK_32X8,
214      TX_16X64 => BLOCK_16X64,
215      TX_64X16 => BLOCK_64X16,
216    }
217  }
218
219  #[inline]
220  pub const fn sqr(self) -> TxSize {
221    match self {
222      TX_4X4 | TX_4X8 | TX_8X4 | TX_4X16 | TX_16X4 => TX_4X4,
223      TX_8X8 | TX_8X16 | TX_16X8 | TX_8X32 | TX_32X8 => TX_8X8,
224      TX_16X16 | TX_16X32 | TX_32X16 | TX_16X64 | TX_64X16 => TX_16X16,
225      TX_32X32 | TX_32X64 | TX_64X32 => TX_32X32,
226      TX_64X64 => TX_64X64,
227    }
228  }
229
230  #[inline]
231  pub const fn sqr_up(self) -> TxSize {
232    match self {
233      TX_4X4 => TX_4X4,
234      TX_8X8 | TX_4X8 | TX_8X4 => TX_8X8,
235      TX_16X16 | TX_8X16 | TX_16X8 | TX_4X16 | TX_16X4 => TX_16X16,
236      TX_32X32 | TX_16X32 | TX_32X16 | TX_8X32 | TX_32X8 => TX_32X32,
237      TX_64X64 | TX_32X64 | TX_64X32 | TX_16X64 | TX_64X16 => TX_64X64,
238    }
239  }
240
241  #[inline]
242  pub fn by_dims(w: usize, h: usize) -> TxSize {
243    match (w, h) {
244      (4, 4) => TX_4X4,
245      (8, 8) => TX_8X8,
246      (16, 16) => TX_16X16,
247      (32, 32) => TX_32X32,
248      (64, 64) => TX_64X64,
249      (4, 8) => TX_4X8,
250      (8, 4) => TX_8X4,
251      (8, 16) => TX_8X16,
252      (16, 8) => TX_16X8,
253      (16, 32) => TX_16X32,
254      (32, 16) => TX_32X16,
255      (32, 64) => TX_32X64,
256      (64, 32) => TX_64X32,
257      (4, 16) => TX_4X16,
258      (16, 4) => TX_16X4,
259      (8, 32) => TX_8X32,
260      (32, 8) => TX_32X8,
261      (16, 64) => TX_16X64,
262      (64, 16) => TX_64X16,
263      _ => unreachable!(),
264    }
265  }
266
267  #[inline]
268  pub const fn is_rect(self) -> bool {
269    self.width_log2() != self.height_log2()
270  }
271
272  /// Returns log2(width / height), e.g. `TX_16x4` -> log2(16 / 4) = 2
273  #[inline]
274  pub const fn rect_ratio_log2(self) -> i8 {
275    self.width_log2() as i8 - self.height_log2() as i8
276  }
277}
278
/// Sets of transform types that AV1 can signal for a block.
///
/// Variants are declared from the smallest set to the largest, and the derived
/// `PartialOrd` follows that declaration order. `Debug` is derived so values
/// can be printed and used with `assert_eq!`.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd)]
pub enum TxSet {
  /// DCT only.
  TX_SET_DCTONLY,
  /// DCT + identity only (a.k.a. `TX_SET_DCT_IDTX`).
  TX_SET_INTER_3,
  /// Discrete trig transforms without flip (4) + identity (1)
  /// (a.k.a. `TX_SET_DTT4_IDTX`).
  TX_SET_INTRA_2,
  /// Discrete trig transforms without flip (4) + identity (1) + 1-D
  /// horizontal/vertical DCT (2) (a.k.a. `TX_SET_DTT4_IDTX_1DDCT`).
  TX_SET_INTRA_1,
  /// Discrete trig transforms with flip (9) + identity (1) + 1-D
  /// horizontal/vertical DCT (2) (a.k.a. `TX_SET_DTT9_IDTX_1DDCT`).
  TX_SET_INTER_2,
  /// Discrete trig transforms with flip (9) + identity (1) + 1-D
  /// horizontal/vertical transforms (6) (a.k.a. `TX_SET_ALL16`).
  TX_SET_INTER_1,
}
294
/// Computes half of a butterfly: `(w0 * in0 + w1 * in1)` rounded and shifted
/// right by `bit`.
///
/// The sum and the rounding offset are added with wrapping arithmetic. This
/// keeps behaviour defined when `w0*in0 + w1*in1` overflows negative but the
/// rounded result does not. With `bit == 0` the raw sum is returned unchanged.
#[inline]
const fn half_btf(w0: i32, in0: i32, w1: i32, in1: i32, bit: usize) -> i32 {
  let sum = (w0 * in0).wrapping_add(w1 * in1);
  match bit {
    0 => sum,
    _ => {
      // round_shift, but tolerant of wrap-around in the intermediate sum.
      let rounding: i32 = 1 << (bit - 1);
      sum.wrapping_add(rounding) >> bit
    }
  }
}
308
/// Clamps `value` to the range of a signed integer that is `bit` bits wide,
/// i.e. `[-(2^(bit-1)), 2^(bit-1) - 1]`.
///
/// A width of 32 bits or more can represent every `i32`, so `value` is
/// returned unchanged in that case. Computing the bounds there would truncate
/// them, or overflow the shift for `bit >= 64`.
///
/// `bit` must be at least 1.
#[inline]
fn clamp_value(value: i32, bit: usize) -> i32 {
  debug_assert!(bit > 0, "cannot clamp to a zero-bit signed integer");
  if bit >= 32 {
    return value;
  }
  // bit is in 1..=31 here, so both bounds fit in an i32 without overflow.
  let max_value = (1i32 << (bit - 1)) - 1;
  let min_value = -(1i32 << (bit - 1));
  // Fully qualified to avoid any clash with a free `clamp` helper in scope.
  Ord::clamp(value, min_value, max_value)
}
316
317pub fn av1_round_shift_array(arr: &mut [i32], size: usize, bit: i8) {
318  if bit == 0 {
319    return;
320  }
321  if bit > 0 {
322    let bit = bit as usize;
323    arr.iter_mut().take(size).for_each(|i| {
324      *i = round_shift(*i, bit);
325    })
326  } else {
327    arr.iter_mut().take(size).for_each(|i| {
328      *i <<= -bit;
329    })
330  }
331}
332
/// One-dimensional transform kernel applied along a single direction.
#[derive(Clone, Copy, Debug)]
enum TxType1D {
  /// Discrete cosine transform.
  DCT = 0,
  /// Asymmetric discrete sine transform.
  ADST = 1,
  /// Flipped ADST.
  FLIPADST = 2,
  /// Identity transform.
  IDTX = 3,
  /// Walsh-Hadamard transform.
  WHT = 4,
}
341
342const fn get_1d_tx_types(tx_type: TxType) -> (TxType1D, TxType1D) {
343  match tx_type {
344    TxType::DCT_DCT => (TxType1D::DCT, TxType1D::DCT),
345    TxType::ADST_DCT => (TxType1D::ADST, TxType1D::DCT),
346    TxType::DCT_ADST => (TxType1D::DCT, TxType1D::ADST),
347    TxType::ADST_ADST => (TxType1D::ADST, TxType1D::ADST),
348    TxType::FLIPADST_DCT => (TxType1D::FLIPADST, TxType1D::DCT),
349    TxType::DCT_FLIPADST => (TxType1D::DCT, TxType1D::FLIPADST),
350    TxType::FLIPADST_FLIPADST => (TxType1D::FLIPADST, TxType1D::FLIPADST),
351    TxType::ADST_FLIPADST => (TxType1D::ADST, TxType1D::FLIPADST),
352    TxType::FLIPADST_ADST => (TxType1D::FLIPADST, TxType1D::ADST),
353    TxType::IDTX => (TxType1D::IDTX, TxType1D::IDTX),
354    TxType::V_DCT => (TxType1D::DCT, TxType1D::IDTX),
355    TxType::H_DCT => (TxType1D::IDTX, TxType1D::DCT),
356    TxType::V_ADST => (TxType1D::ADST, TxType1D::IDTX),
357    TxType::H_ADST => (TxType1D::IDTX, TxType1D::ADST),
358    TxType::V_FLIPADST => (TxType1D::FLIPADST, TxType1D::IDTX),
359    TxType::H_FLIPADST => (TxType1D::IDTX, TxType1D::FLIPADST),
360    TxType::WHT_WHT => (TxType1D::WHT, TxType1D::WHT),
361  }
362}
363
364const VTX_TAB: [TxType1D; TX_TYPES_PLUS_LL] = [
365  TxType1D::DCT,
366  TxType1D::ADST,
367  TxType1D::DCT,
368  TxType1D::ADST,
369  TxType1D::FLIPADST,
370  TxType1D::DCT,
371  TxType1D::FLIPADST,
372  TxType1D::ADST,
373  TxType1D::FLIPADST,
374  TxType1D::IDTX,
375  TxType1D::DCT,
376  TxType1D::IDTX,
377  TxType1D::ADST,
378  TxType1D::IDTX,
379  TxType1D::FLIPADST,
380  TxType1D::IDTX,
381  TxType1D::WHT,
382];
383
384const HTX_TAB: [TxType1D; TX_TYPES_PLUS_LL] = [
385  TxType1D::DCT,
386  TxType1D::DCT,
387  TxType1D::ADST,
388  TxType1D::ADST,
389  TxType1D::DCT,
390  TxType1D::FLIPADST,
391  TxType1D::FLIPADST,
392  TxType1D::FLIPADST,
393  TxType1D::ADST,
394  TxType1D::IDTX,
395  TxType1D::IDTX,
396  TxType1D::DCT,
397  TxType1D::IDTX,
398  TxType1D::ADST,
399  TxType1D::IDTX,
400  TxType1D::FLIPADST,
401  TxType1D::WHT,
402];
403
404#[inline]
405pub const fn valid_av1_transform(tx_size: TxSize, tx_type: TxType) -> bool {
406  let size_sq = tx_size.sqr_up();
407  use TxSize::*;
408  use TxType::*;
409  match (size_sq, tx_type) {
410    (TX_64X64, DCT_DCT) => true,
411    (TX_64X64, _) => false,
412    (TX_32X32, DCT_DCT) => true,
413    (TX_32X32, IDTX) => true,
414    (TX_32X32, _) => false,
415    (_, _) => true,
416  }
417}
418
/// Lists every transform type considered valid for `tx_size`
/// (test/bench helper).
///
/// 64-point sizes allow only `DCT_DCT`, 32-point sizes allow `DCT_DCT` and
/// `IDTX`, `TX_4X4` allows the 16 regular types plus `WHT_WHT`, and every
/// other size allows the 16 regular types.
#[cfg(any(test, feature = "bench"))]
pub fn get_valid_txfm_types(tx_size: TxSize) -> &'static [TxType] {
  use TxType::*;

  // The 16 regular (non-Walsh-Hadamard) types, in discriminant order.
  const REGULAR: &[TxType] = &[
    DCT_DCT,
    ADST_DCT,
    DCT_ADST,
    ADST_ADST,
    FLIPADST_DCT,
    DCT_FLIPADST,
    FLIPADST_FLIPADST,
    ADST_FLIPADST,
    FLIPADST_ADST,
    IDTX,
    V_DCT,
    H_DCT,
    V_ADST,
    H_ADST,
    V_FLIPADST,
    H_FLIPADST,
  ];

  // The regular types followed by the Walsh-Hadamard type (4x4 only).
  const REGULAR_AND_WHT: &[TxType] = &[
    DCT_DCT,
    ADST_DCT,
    DCT_ADST,
    ADST_ADST,
    FLIPADST_DCT,
    DCT_FLIPADST,
    FLIPADST_FLIPADST,
    ADST_FLIPADST,
    FLIPADST_ADST,
    IDTX,
    V_DCT,
    H_DCT,
    V_ADST,
    H_ADST,
    V_FLIPADST,
    H_FLIPADST,
    WHT_WHT,
  ];

  match tx_size.sqr_up() {
    TxSize::TX_64X64 => &[DCT_DCT],
    TxSize::TX_32X32 => &[DCT_DCT, IDTX],
    TxSize::TX_4X4 => REGULAR_AND_WHT,
    _ => REGULAR,
  }
}
468
#[cfg(test)]
mod test {
  use super::TxType::*;
  use super::*;
  use crate::context::av1_get_coded_tx_size;
  use crate::cpu_features::CpuFeatureLevel;
  use crate::frame::*;
  use rand::random;
  use std::mem::MaybeUninit;

  /// Every transform size, in declaration order.
  const ALL_TX_SIZES: [TxSize; TxSize::TX_SIZES_ALL] = [
    TX_4X4, TX_8X8, TX_16X16, TX_32X32, TX_64X64, TX_4X8, TX_8X4, TX_8X16,
    TX_16X8, TX_16X32, TX_32X16, TX_32X64, TX_64X32, TX_4X16, TX_16X4,
    TX_8X32, TX_32X8, TX_16X64, TX_64X16,
  ];

  /// Forward-transforms a random residual, inverse-transforms it back onto the
  /// prediction, and checks every reconstructed pixel is within `tolerance`
  /// of the source pixel.
  fn test_roundtrip<T: Pixel>(
    tx_size: TxSize, tx_type: TxType, tolerance: i16,
  ) {
    let cpu = CpuFeatureLevel::default();

    let coeff_area: usize = av1_get_coded_tx_size(tx_size).area();
    let mut src_storage = [T::cast_from(0); 64 * 64];
    let src = &mut src_storage[..tx_size.area()];
    let mut dst = Plane::from_slice(
      &[T::zero(); 64 * 64][..tx_size.area()],
      tx_size.width(),
    );
    let mut res_storage = [0i16; 64 * 64];
    let res = &mut res_storage[..tx_size.area()];
    let mut freq_storage = [MaybeUninit::uninit(); 64 * 64];
    let freq = &mut freq_storage[..tx_size.area()];
    for ((r, s), d) in
      res.iter_mut().zip(src.iter_mut()).zip(dst.data.iter_mut())
    {
      *s = T::cast_from(random::<u8>());
      *d = T::cast_from(random::<u8>());
      *r = i16::cast_from(*s) - i16::cast_from(*d);
    }
    forward_transform(res, freq, tx_size.width(), tx_size, tx_type, 8, cpu);
    // SAFETY: forward_transform initialized freq
    let freq = unsafe { slice_assume_init_mut(freq) };
    inverse_transform_add(
      freq,
      &mut dst.as_region_mut(),
      coeff_area.try_into().unwrap(),
      tx_size,
      tx_type,
      8,
      cpu,
    );

    for (i, (s, d)) in src.iter().zip(dst.data.iter()).enumerate() {
      let diff = i16::abs(i16::cast_from(*s) - i16::cast_from(*d));
      assert!(
        diff <= tolerance,
        "{:?} {:?}: pixel {} differs by {} (tolerance {})",
        tx_size,
        tx_type,
        i,
        diff,
        tolerance
      );
    }
  }

  #[test]
  fn log_tx_ratios() {
    let combinations = [
      (TxSize::TX_4X4, 0),
      (TxSize::TX_8X8, 0),
      (TxSize::TX_16X16, 0),
      (TxSize::TX_32X32, 0),
      (TxSize::TX_64X64, 0),
      (TxSize::TX_4X8, -1),
      (TxSize::TX_8X4, 1),
      (TxSize::TX_8X16, -1),
      (TxSize::TX_16X8, 1),
      (TxSize::TX_16X32, -1),
      (TxSize::TX_32X16, 1),
      (TxSize::TX_32X64, -1),
      (TxSize::TX_64X32, 1),
      (TxSize::TX_4X16, -2),
      (TxSize::TX_16X4, 2),
      (TxSize::TX_8X32, -2),
      (TxSize::TX_32X8, 2),
      (TxSize::TX_16X64, -2),
      (TxSize::TX_64X16, 2),
    ];

    for &(tx_size, expected) in combinations.iter() {
      println!(
        "Testing combination {:?}, {:?}",
        tx_size.width(),
        tx_size.height()
      );
      assert_eq!(
        tx_size.rect_ratio_log2(),
        expected,
        "unexpected rect_ratio_log2 for {:?}",
        tx_size
      );
    }
  }

  /// Every type that `get_valid_txfm_types` lists must also be accepted by
  /// `valid_av1_transform`.
  #[test]
  fn valid_txfm_types_are_valid_av1_transforms() {
    for &tx_size in ALL_TX_SIZES.iter() {
      for &tx_type in get_valid_txfm_types(tx_size) {
        assert!(
          valid_av1_transform(tx_size, tx_type),
          "{:?} is listed for {:?} but rejected by valid_av1_transform",
          tx_type,
          tx_size
        );
      }
    }
  }

  fn roundtrips<T: Pixel>() {
    let combinations = [
      (TX_4X4, WHT_WHT, 0),
      (TX_4X4, DCT_DCT, 0),
      (TX_4X4, ADST_DCT, 0),
      (TX_4X4, DCT_ADST, 0),
      (TX_4X4, ADST_ADST, 0),
      (TX_4X4, FLIPADST_DCT, 0),
      (TX_4X4, DCT_FLIPADST, 0),
      (TX_4X4, IDTX, 0),
      (TX_4X4, V_DCT, 0),
      (TX_4X4, H_DCT, 0),
      (TX_4X4, V_ADST, 0),
      (TX_4X4, H_ADST, 0),
      (TX_8X8, DCT_DCT, 1),
      (TX_8X8, ADST_DCT, 1),
      (TX_8X8, DCT_ADST, 1),
      (TX_8X8, ADST_ADST, 1),
      (TX_8X8, FLIPADST_DCT, 1),
      (TX_8X8, DCT_FLIPADST, 1),
      (TX_8X8, IDTX, 0),
      (TX_8X8, V_DCT, 0),
      (TX_8X8, H_DCT, 0),
      (TX_8X8, V_ADST, 0),
      (TX_8X8, H_ADST, 1),
      (TX_16X16, DCT_DCT, 1),
      (TX_16X16, ADST_DCT, 1),
      (TX_16X16, DCT_ADST, 1),
      (TX_16X16, ADST_ADST, 1),
      (TX_16X16, FLIPADST_DCT, 1),
      (TX_16X16, DCT_FLIPADST, 1),
      (TX_16X16, IDTX, 0),
      (TX_16X16, V_DCT, 1),
      (TX_16X16, H_DCT, 1),
      // 32x transforms only use DCT_DCT and IDTX
      (TX_32X32, DCT_DCT, 2),
      (TX_32X32, IDTX, 0),
      // 64x transforms only use DCT_DCT.
      // NOTE(review): this round trip is disabled; the reason is not
      // recorded here.
      //(TX_64X64, DCT_DCT, 0),
      (TX_4X8, DCT_DCT, 1),
      (TX_8X4, DCT_DCT, 1),
      (TX_4X16, DCT_DCT, 1),
      (TX_16X4, DCT_DCT, 1),
      (TX_8X16, DCT_DCT, 1),
      (TX_16X8, DCT_DCT, 1),
      (TX_8X32, DCT_DCT, 2),
      (TX_32X8, DCT_DCT, 2),
      (TX_16X32, DCT_DCT, 2),
      (TX_32X16, DCT_DCT, 2),
    ];
    for &(tx_size, tx_type, tolerance) in combinations.iter() {
      println!("Testing combination {:?}, {:?}", tx_size, tx_type);
      test_roundtrip::<T>(tx_size, tx_type, tolerance);
    }
  }

  #[test]
  fn roundtrips_u8() {
    roundtrips::<u8>();
  }

  #[test]
  fn roundtrips_u16() {
    roundtrips::<u16>();
  }
}