1#![allow(non_camel_case_types)]
11#![allow(dead_code)]
12
13#[macro_use]
14pub mod forward_shared;
15
16pub use self::forward::forward_transform;
17pub use self::inverse::inverse_transform_add;
18
19use crate::context::MI_SIZE_LOG2;
20use crate::partition::{BlockSize, BlockSize::*};
21use crate::util::*;
22
23use TxSize::*;
24
25pub mod forward;
26pub mod inverse;
27
28pub static RAV1E_TX_TYPES: &[TxType] = &[
29 TxType::DCT_DCT,
30 TxType::ADST_DCT,
31 TxType::DCT_ADST,
32 TxType::ADST_ADST,
33 TxType::IDTX,
40 TxType::V_DCT,
41 TxType::H_DCT,
42 ];
45
/// Fixed-point `sqrt(2)` scaling constants with `SQRT2_BITS` fractional bits.
pub mod consts {
  /// Number of fractional bits carried by [`SQRT2`] and [`INV_SQRT2`].
  pub static SQRT2_BITS: usize = 12;

  /// `round(sqrt(2) * 2^12)`.
  pub static SQRT2: i32 = 5793;

  /// `round(2^12 / sqrt(2))`.
  pub static INV_SQRT2: i32 = 2896;
}
51
/// Number of regular 2-D transform types (`DCT_DCT` through `H_FLIPADST`).
pub const TX_TYPES: usize = 16;
/// [`TX_TYPES`] plus one extra slot for `WHT_WHT` (discriminant 16).
pub const TX_TYPES_PLUS_LL: usize = TX_TYPES + 1;
54
/// Two-dimensional transform type.
///
/// Separable types are named `<vertical>_<horizontal>`: `ADST_DCT` uses ADST
/// in the vertical direction and DCT in the horizontal one (compare the
/// `VTX_TAB` / `HTX_TAB` tables). `V_*` types transform vertically only and
/// `H_*` types horizontally only, with identity in the other direction.
///
/// Discriminants are explicit; `VTX_TAB` and `HTX_TAB` list their entries
/// in this discriminant order.
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord)]
pub enum TxType {
  // DCT / ADST combinations.
  DCT_DCT = 0,
  ADST_DCT = 1,
  DCT_ADST = 2,
  ADST_ADST = 3,
  // Combinations involving a flipped ADST.
  FLIPADST_DCT = 4,
  DCT_FLIPADST = 5,
  FLIPADST_FLIPADST = 6,
  ADST_FLIPADST = 7,
  FLIPADST_ADST = 8,
  // Identity in both directions.
  IDTX = 9,
  // One-dimensional transforms (identity in the other direction).
  V_DCT = 10,
  H_DCT = 11,
  V_ADST = 12,
  H_ADST = 13,
  V_FLIPADST = 14,
  H_FLIPADST = 15,
  // Walsh-Hadamard; only offered for TX_4X4 by `get_valid_txfm_types`.
  WHT_WHT = 16,
}
75
76impl TxType {
77 #[inline]
81 pub fn uv_inter(self, uv_tx_size: TxSize) -> Self {
82 use TxType::*;
83 if uv_tx_size.sqr_up() == TX_32X32 {
84 match self {
85 IDTX => IDTX,
86 _ => DCT_DCT,
87 }
88 } else if uv_tx_size.sqr() == TX_16X16 {
89 match self {
90 V_ADST | H_ADST | V_FLIPADST | H_FLIPADST => DCT_DCT,
91 _ => self,
92 }
93 } else {
94 self
95 }
96 }
97}
98
/// Transform block dimensions, named `TX_<width>X<height>`.
///
/// The five square sizes come first, then the 1:2 / 2:1 rectangles, then the
/// 1:4 / 4:1 rectangles. The explicit discriminants below match the implicit
/// declaration-order values, and the derived ordering follows them.
#[derive(Copy, Clone, Debug, PartialEq, PartialOrd, Eq, Ord)]
pub enum TxSize {
  // Square.
  TX_4X4 = 0,
  TX_8X8 = 1,
  TX_16X16 = 2,
  TX_32X32 = 3,
  TX_64X64 = 4,
  // 1:2 and 2:1.
  TX_4X8 = 5,
  TX_8X4 = 6,
  TX_8X16 = 7,
  TX_16X8 = 8,
  TX_16X32 = 9,
  TX_32X16 = 10,
  TX_32X64 = 11,
  TX_64X32 = 12,
  // 1:4 and 4:1.
  TX_4X16 = 13,
  TX_16X4 = 14,
  TX_8X32 = 15,
  TX_32X8 = 16,
  TX_16X64 = 17,
  TX_64X16 = 18,
}
124
125impl TxSize {
126 pub const TX_SIZES: usize = 5;
128
129 pub const TX_SIZES_ALL: usize = 14 + 5;
131
132 #[inline]
133 pub const fn width(self) -> usize {
134 1 << self.width_log2()
135 }
136
137 #[inline]
138 pub const fn width_log2(self) -> usize {
139 match self {
140 TX_4X4 | TX_4X8 | TX_4X16 => 2,
141 TX_8X8 | TX_8X4 | TX_8X16 | TX_8X32 => 3,
142 TX_16X16 | TX_16X8 | TX_16X32 | TX_16X4 | TX_16X64 => 4,
143 TX_32X32 | TX_32X16 | TX_32X64 | TX_32X8 => 5,
144 TX_64X64 | TX_64X32 | TX_64X16 => 6,
145 }
146 }
147
148 #[inline]
149 pub const fn width_index(self) -> usize {
150 self.width_log2() - TX_4X4.width_log2()
151 }
152
153 #[inline]
154 pub const fn height(self) -> usize {
155 1 << self.height_log2()
156 }
157
158 #[inline]
159 pub const fn height_log2(self) -> usize {
160 match self {
161 TX_4X4 | TX_8X4 | TX_16X4 => 2,
162 TX_8X8 | TX_4X8 | TX_16X8 | TX_32X8 => 3,
163 TX_16X16 | TX_8X16 | TX_32X16 | TX_4X16 | TX_64X16 => 4,
164 TX_32X32 | TX_16X32 | TX_64X32 | TX_8X32 => 5,
165 TX_64X64 | TX_32X64 | TX_16X64 => 6,
166 }
167 }
168
169 #[inline]
170 pub const fn height_index(self) -> usize {
171 self.height_log2() - TX_4X4.height_log2()
172 }
173
174 #[inline]
175 pub const fn width_mi(self) -> usize {
176 self.width() >> MI_SIZE_LOG2
177 }
178
179 #[inline]
180 pub const fn area(self) -> usize {
181 1 << self.area_log2()
182 }
183
184 #[inline]
185 pub const fn area_log2(self) -> usize {
186 self.width_log2() + self.height_log2()
187 }
188
189 #[inline]
190 pub const fn height_mi(self) -> usize {
191 self.height() >> MI_SIZE_LOG2
192 }
193
194 #[inline]
195 pub const fn block_size(self) -> BlockSize {
196 match self {
197 TX_4X4 => BLOCK_4X4,
198 TX_8X8 => BLOCK_8X8,
199 TX_16X16 => BLOCK_16X16,
200 TX_32X32 => BLOCK_32X32,
201 TX_64X64 => BLOCK_64X64,
202 TX_4X8 => BLOCK_4X8,
203 TX_8X4 => BLOCK_8X4,
204 TX_8X16 => BLOCK_8X16,
205 TX_16X8 => BLOCK_16X8,
206 TX_16X32 => BLOCK_16X32,
207 TX_32X16 => BLOCK_32X16,
208 TX_32X64 => BLOCK_32X64,
209 TX_64X32 => BLOCK_64X32,
210 TX_4X16 => BLOCK_4X16,
211 TX_16X4 => BLOCK_16X4,
212 TX_8X32 => BLOCK_8X32,
213 TX_32X8 => BLOCK_32X8,
214 TX_16X64 => BLOCK_16X64,
215 TX_64X16 => BLOCK_64X16,
216 }
217 }
218
219 #[inline]
220 pub const fn sqr(self) -> TxSize {
221 match self {
222 TX_4X4 | TX_4X8 | TX_8X4 | TX_4X16 | TX_16X4 => TX_4X4,
223 TX_8X8 | TX_8X16 | TX_16X8 | TX_8X32 | TX_32X8 => TX_8X8,
224 TX_16X16 | TX_16X32 | TX_32X16 | TX_16X64 | TX_64X16 => TX_16X16,
225 TX_32X32 | TX_32X64 | TX_64X32 => TX_32X32,
226 TX_64X64 => TX_64X64,
227 }
228 }
229
230 #[inline]
231 pub const fn sqr_up(self) -> TxSize {
232 match self {
233 TX_4X4 => TX_4X4,
234 TX_8X8 | TX_4X8 | TX_8X4 => TX_8X8,
235 TX_16X16 | TX_8X16 | TX_16X8 | TX_4X16 | TX_16X4 => TX_16X16,
236 TX_32X32 | TX_16X32 | TX_32X16 | TX_8X32 | TX_32X8 => TX_32X32,
237 TX_64X64 | TX_32X64 | TX_64X32 | TX_16X64 | TX_64X16 => TX_64X64,
238 }
239 }
240
241 #[inline]
242 pub fn by_dims(w: usize, h: usize) -> TxSize {
243 match (w, h) {
244 (4, 4) => TX_4X4,
245 (8, 8) => TX_8X8,
246 (16, 16) => TX_16X16,
247 (32, 32) => TX_32X32,
248 (64, 64) => TX_64X64,
249 (4, 8) => TX_4X8,
250 (8, 4) => TX_8X4,
251 (8, 16) => TX_8X16,
252 (16, 8) => TX_16X8,
253 (16, 32) => TX_16X32,
254 (32, 16) => TX_32X16,
255 (32, 64) => TX_32X64,
256 (64, 32) => TX_64X32,
257 (4, 16) => TX_4X16,
258 (16, 4) => TX_16X4,
259 (8, 32) => TX_8X32,
260 (32, 8) => TX_32X8,
261 (16, 64) => TX_16X64,
262 (64, 16) => TX_64X16,
263 _ => unreachable!(),
264 }
265 }
266
267 #[inline]
268 pub const fn is_rect(self) -> bool {
269 self.width_log2() != self.height_log2()
270 }
271
272 #[inline]
274 pub const fn rect_ratio_log2(self) -> i8 {
275 self.width_log2() as i8 - self.height_log2() as i8
276 }
277}
278
/// Transform-type sets, using the AV1 specification's set names.
///
/// Variants are declared in increasing order of the number of types each set
/// allows (per the AV1 spec: 1, 2, 5, 7, 12 and 16). The derived `PartialOrd`
/// therefore ranks smaller sets first, so the declaration order must not be
/// changed.
///
/// `Debug` is derived so the type can be printed in assertions and logs like
/// the sibling `TxType` / `TxSize` enums.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd)]
pub enum TxSet {
  /// `DCT_DCT` only.
  TX_SET_DCTONLY,
  /// AV1 spec: `DCT_DCT` and `IDTX`.
  TX_SET_INTER_3,
  /// AV1 spec: the four DCT/ADST combinations plus `IDTX`.
  TX_SET_INTRA_2,
  /// AV1 spec: `TX_SET_INTRA_2` plus `V_DCT` and `H_DCT`.
  TX_SET_INTRA_1,
  /// AV1 spec: the nine DCT/ADST/FLIPADST combinations, `IDTX`, `V_DCT` and `H_DCT`.
  TX_SET_INTER_2,
  /// AV1 spec: all 16 regular transform types.
  TX_SET_INTER_1,
}
294
/// Computes the butterfly term `w0 * in0 + w1 * in1` and rounds it right by
/// `bit` bits, rounding halves upward (`bit == 0` returns the sum as is).
///
/// The sum and the rounding offset use wrapping addition. As a result, a
/// negative sum that only overflows once the offset is added wraps instead of
/// panicking. The two products are ordinary multiplications.
#[inline]
const fn half_btf(w0: i32, in0: i32, w1: i32, in1: i32, bit: usize) -> i32 {
  let sum = (w0 * in0).wrapping_add(w1 * in1);
  match bit {
    0 => sum,
    _ => {
      let rounding: i32 = 1 << (bit - 1);
      sum.wrapping_add(rounding) >> bit
    }
  }
}
308
309#[inline]
311fn clamp_value(value: i32, bit: usize) -> i32 {
312 let max_value: i32 = ((1i64 << (bit - 1)) - 1) as i32;
313 let min_value: i32 = (-(1i64 << (bit - 1))) as i32;
314 clamp(value, min_value, max_value)
315}
316
317pub fn av1_round_shift_array(arr: &mut [i32], size: usize, bit: i8) {
318 if bit == 0 {
319 return;
320 }
321 if bit > 0 {
322 let bit = bit as usize;
323 arr.iter_mut().take(size).for_each(|i| {
324 *i = round_shift(*i, bit);
325 })
326 } else {
327 arr.iter_mut().take(size).for_each(|i| {
328 *i <<= -bit;
329 })
330 }
331}
332
/// One-dimensional transform kernels; each 2-D [`TxType`] pairs two of them.
#[derive(Debug, Clone, Copy)]
enum TxType1D {
  /// Discrete cosine transform.
  DCT,
  /// Asymmetric discrete sine transform.
  ADST,
  /// ADST with its sample order flipped.
  FLIPADST,
  /// Identity transform.
  IDTX,
  /// Walsh-Hadamard transform.
  WHT,
}
341
342const fn get_1d_tx_types(tx_type: TxType) -> (TxType1D, TxType1D) {
343 match tx_type {
344 TxType::DCT_DCT => (TxType1D::DCT, TxType1D::DCT),
345 TxType::ADST_DCT => (TxType1D::ADST, TxType1D::DCT),
346 TxType::DCT_ADST => (TxType1D::DCT, TxType1D::ADST),
347 TxType::ADST_ADST => (TxType1D::ADST, TxType1D::ADST),
348 TxType::FLIPADST_DCT => (TxType1D::FLIPADST, TxType1D::DCT),
349 TxType::DCT_FLIPADST => (TxType1D::DCT, TxType1D::FLIPADST),
350 TxType::FLIPADST_FLIPADST => (TxType1D::FLIPADST, TxType1D::FLIPADST),
351 TxType::ADST_FLIPADST => (TxType1D::ADST, TxType1D::FLIPADST),
352 TxType::FLIPADST_ADST => (TxType1D::FLIPADST, TxType1D::ADST),
353 TxType::IDTX => (TxType1D::IDTX, TxType1D::IDTX),
354 TxType::V_DCT => (TxType1D::DCT, TxType1D::IDTX),
355 TxType::H_DCT => (TxType1D::IDTX, TxType1D::DCT),
356 TxType::V_ADST => (TxType1D::ADST, TxType1D::IDTX),
357 TxType::H_ADST => (TxType1D::IDTX, TxType1D::ADST),
358 TxType::V_FLIPADST => (TxType1D::FLIPADST, TxType1D::IDTX),
359 TxType::H_FLIPADST => (TxType1D::IDTX, TxType1D::FLIPADST),
360 TxType::WHT_WHT => (TxType1D::WHT, TxType1D::WHT),
361 }
362}
363
364const VTX_TAB: [TxType1D; TX_TYPES_PLUS_LL] = [
365 TxType1D::DCT,
366 TxType1D::ADST,
367 TxType1D::DCT,
368 TxType1D::ADST,
369 TxType1D::FLIPADST,
370 TxType1D::DCT,
371 TxType1D::FLIPADST,
372 TxType1D::ADST,
373 TxType1D::FLIPADST,
374 TxType1D::IDTX,
375 TxType1D::DCT,
376 TxType1D::IDTX,
377 TxType1D::ADST,
378 TxType1D::IDTX,
379 TxType1D::FLIPADST,
380 TxType1D::IDTX,
381 TxType1D::WHT,
382];
383
384const HTX_TAB: [TxType1D; TX_TYPES_PLUS_LL] = [
385 TxType1D::DCT,
386 TxType1D::DCT,
387 TxType1D::ADST,
388 TxType1D::ADST,
389 TxType1D::DCT,
390 TxType1D::FLIPADST,
391 TxType1D::FLIPADST,
392 TxType1D::FLIPADST,
393 TxType1D::ADST,
394 TxType1D::IDTX,
395 TxType1D::IDTX,
396 TxType1D::DCT,
397 TxType1D::IDTX,
398 TxType1D::ADST,
399 TxType1D::IDTX,
400 TxType1D::FLIPADST,
401 TxType1D::WHT,
402];
403
404#[inline]
405pub const fn valid_av1_transform(tx_size: TxSize, tx_type: TxType) -> bool {
406 let size_sq = tx_size.sqr_up();
407 use TxSize::*;
408 use TxType::*;
409 match (size_sq, tx_type) {
410 (TX_64X64, DCT_DCT) => true,
411 (TX_64X64, _) => false,
412 (TX_32X32, DCT_DCT) => true,
413 (TX_32X32, IDTX) => true,
414 (TX_32X32, _) => false,
415 (_, _) => true,
416 }
417}
418
/// Returns every transform type usable with `tx_size`, keyed on
/// `tx_size.sqr_up()`:
/// * 64-point sizes: `DCT_DCT` only;
/// * 32-point sizes: `DCT_DCT` and `IDTX`;
/// * `TX_4X4`: all 17 types, including `WHT_WHT`;
/// * anything else: the 16 regular types (no `WHT_WHT`).
#[cfg(any(test, feature = "bench"))]
pub fn get_valid_txfm_types(tx_size: TxSize) -> &'static [TxType] {
  use TxType::*;

  // Every type in discriminant order. WHT_WHT sits last, so the first 16
  // entries are exactly the regular types.
  static ALL_TYPES: [TxType; 17] = [
    DCT_DCT,
    ADST_DCT,
    DCT_ADST,
    ADST_ADST,
    FLIPADST_DCT,
    DCT_FLIPADST,
    FLIPADST_FLIPADST,
    ADST_FLIPADST,
    FLIPADST_ADST,
    IDTX,
    V_DCT,
    H_DCT,
    V_ADST,
    H_ADST,
    V_FLIPADST,
    H_FLIPADST,
    WHT_WHT,
  ];

  match tx_size.sqr_up() {
    TxSize::TX_64X64 => &ALL_TYPES[..1],
    TxSize::TX_32X32 => &[DCT_DCT, IDTX],
    TxSize::TX_4X4 => &ALL_TYPES[..],
    _ => &ALL_TYPES[..16],
  }
}
468
#[cfg(test)]
mod test {
  use super::TxType::*;
  use super::*;
  use crate::context::av1_get_coded_tx_size;
  use crate::cpu_features::CpuFeatureLevel;
  use crate::frame::*;
  use rand::random;
  use std::mem::MaybeUninit;

  /// Forward-transforms a random residual, applies `inverse_transform_add`
  /// onto the prediction it was derived from, and checks every reconstructed
  /// sample is within `tolerance` of the source.
  fn test_roundtrip<T: Pixel>(
    tx_size: TxSize, tx_type: TxType, tolerance: i16,
  ) {
    let cpu = CpuFeatureLevel::default();

    let coeff_area: usize = av1_get_coded_tx_size(tx_size).area();
    let mut src_storage = [T::cast_from(0); 64 * 64];
    let src = &mut src_storage[..tx_size.area()];
    let mut dst = Plane::from_slice(
      &[T::zero(); 64 * 64][..tx_size.area()],
      tx_size.width(),
    );
    let mut res_storage = [0i16; 64 * 64];
    let res = &mut res_storage[..tx_size.area()];
    let mut freq_storage = [MaybeUninit::uninit(); 64 * 64];
    let freq = &mut freq_storage[..tx_size.area()];
    // Random source and prediction; the residual is their difference.
    for ((r, s), d) in
      res.iter_mut().zip(src.iter_mut()).zip(dst.data.iter_mut())
    {
      *s = T::cast_from(random::<u8>());
      *d = T::cast_from(random::<u8>());
      *r = i16::cast_from(*s) - i16::cast_from(*d);
    }
    forward_transform(res, freq, tx_size.width(), tx_size, tx_type, 8, cpu);
    // SAFETY: relies on `forward_transform` initializing every element of
    // `freq`. NOTE(review): confirm this also holds for sizes where
    // `coeff_area` < `tx_size.area()` before adding them to `roundtrips`.
    let freq = unsafe { slice_assume_init_mut(freq) };

    inverse_transform_add(
      freq,
      &mut dst.as_region_mut(),
      coeff_area.try_into().unwrap(),
      tx_size,
      tx_type,
      8,
      cpu,
    );

    for (s, d) in src.iter().zip(dst.data.iter()) {
      let diff = i16::cast_from(*s) - i16::cast_from(*d);
      assert!(
        diff.abs() <= tolerance,
        "{:?}/{:?}: reconstruction differs by {} (tolerance {})",
        tx_size,
        tx_type,
        diff,
        tolerance
      );
    }
  }

  #[test]
  fn log_tx_ratios() {
    let combinations = [
      (TxSize::TX_4X4, 0),
      (TxSize::TX_8X8, 0),
      (TxSize::TX_16X16, 0),
      (TxSize::TX_32X32, 0),
      (TxSize::TX_64X64, 0),
      (TxSize::TX_4X8, -1),
      (TxSize::TX_8X4, 1),
      (TxSize::TX_8X16, -1),
      (TxSize::TX_16X8, 1),
      (TxSize::TX_16X32, -1),
      (TxSize::TX_32X16, 1),
      (TxSize::TX_32X64, -1),
      (TxSize::TX_64X32, 1),
      (TxSize::TX_4X16, -2),
      (TxSize::TX_16X4, 2),
      (TxSize::TX_8X32, -2),
      (TxSize::TX_32X8, 2),
      (TxSize::TX_16X64, -2),
      (TxSize::TX_64X16, 2),
    ];

    for &(tx_size, expected) in combinations.iter() {
      assert_eq!(
        tx_size.rect_ratio_log2(),
        expected,
        "unexpected log2 ratio for {}x{}",
        tx_size.width(),
        tx_size.height()
      );
    }
  }

  /// Runs `test_roundtrip` over every (size, type, tolerance) combination
  /// covered by these tests.
  fn roundtrips<T: Pixel>() {
    let combinations = [
      (TX_4X4, WHT_WHT, 0),
      (TX_4X4, DCT_DCT, 0),
      (TX_4X4, ADST_DCT, 0),
      (TX_4X4, DCT_ADST, 0),
      (TX_4X4, ADST_ADST, 0),
      (TX_4X4, FLIPADST_DCT, 0),
      (TX_4X4, DCT_FLIPADST, 0),
      (TX_4X4, IDTX, 0),
      (TX_4X4, V_DCT, 0),
      (TX_4X4, H_DCT, 0),
      (TX_4X4, V_ADST, 0),
      (TX_4X4, H_ADST, 0),
      (TX_8X8, DCT_DCT, 1),
      (TX_8X8, ADST_DCT, 1),
      (TX_8X8, DCT_ADST, 1),
      (TX_8X8, ADST_ADST, 1),
      (TX_8X8, FLIPADST_DCT, 1),
      (TX_8X8, DCT_FLIPADST, 1),
      (TX_8X8, IDTX, 0),
      (TX_8X8, V_DCT, 0),
      (TX_8X8, H_DCT, 0),
      (TX_8X8, V_ADST, 0),
      (TX_8X8, H_ADST, 1),
      (TX_16X16, DCT_DCT, 1),
      (TX_16X16, ADST_DCT, 1),
      (TX_16X16, DCT_ADST, 1),
      (TX_16X16, ADST_ADST, 1),
      (TX_16X16, FLIPADST_DCT, 1),
      (TX_16X16, DCT_FLIPADST, 1),
      (TX_16X16, IDTX, 0),
      (TX_16X16, V_DCT, 1),
      (TX_16X16, H_DCT, 1),
      (TX_32X32, DCT_DCT, 2),
      (TX_32X32, IDTX, 0),
      (TX_4X8, DCT_DCT, 1),
      (TX_8X4, DCT_DCT, 1),
      (TX_4X16, DCT_DCT, 1),
      (TX_16X4, DCT_DCT, 1),
      (TX_8X16, DCT_DCT, 1),
      (TX_16X8, DCT_DCT, 1),
      (TX_8X32, DCT_DCT, 2),
      (TX_32X8, DCT_DCT, 2),
      (TX_16X32, DCT_DCT, 2),
      (TX_32X16, DCT_DCT, 2),
    ];
    for &(tx_size, tx_type, tolerance) in combinations.iter() {
      // Printed up front so a panic inside the transforms still shows which
      // combination was running.
      println!("Testing combination {:?}, {:?}", tx_size, tx_type);
      test_roundtrip::<T>(tx_size, tx_type, tolerance);
    }
  }

  #[test]
  fn roundtrips_u8() {
    roundtrips::<u8>();
  }

  #[test]
  fn roundtrips_u16() {
    roundtrips::<u16>();
  }
}