1#![allow(dead_code)]
24#![allow(clippy::similar_names)]
25#![allow(clippy::many_single_char_names)]
26#![allow(clippy::match_same_arms)]
27#![allow(clippy::cast_precision_loss)]
28#![allow(clippy::cast_possible_truncation)]
29#![allow(clippy::no_effect_underscore_binding)]
30#![allow(clippy::needless_range_loop)]
31
32use std::f64::consts::PI;
33
/// Number of fractional bits carried by the fixed-point cosine/sine factors
/// (the factors are scaled by `1 << COS_BIT`).
pub const COS_BIT: u8 = 14;

/// Half of one unit at `COS_BIT` precision, i.e. `1 << (COS_BIT - 1)`.
pub const COS_ROUND: i32 = (1 << COS_BIT) >> 1;

/// Upper saturation bound for transform coefficients (`2^15 - 1`).
pub const TX_COEFF_MAX: i32 = 0x7FFF;

/// Lower saturation bound for transform coefficients (`-2^15`).
pub const TX_COEFF_MIN: i32 = -0x8000;

/// Number of two-dimensional transform types (variants of `TxType`).
pub const TX_TYPES: usize = 16;

/// Number of transform block sizes (variants of `TxSize`).
pub const TX_SIZES: usize = 19;
55
/// Returns the smaller of two `u32` values; usable in `const` contexts.
const fn const_min_u32(a: u32, b: u32) -> u32 {
    if b <= a {
        b
    } else {
        a
    }
}
64
/// Number of square transform sizes (variants of `TxSizeSqr`).
pub const TX_SIZES_SQ: usize = 5;

/// Largest transform dimension along one axis.
pub const MAX_TX_SIZE: usize = 1 << 6;

/// Coefficient count of the largest square transform (`MAX_TX_SIZE` squared).
pub const MAX_TX_SQUARE: usize = MAX_TX_SIZE.pow(2);
73
/// One-dimensional transform kernel applied along a single axis.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub enum TxType1D {
    /// Discrete cosine transform (the default kernel).
    #[default]
    Dct = 0,
    /// Asymmetric discrete sine transform.
    Adst = 1,
    /// ADST whose output is mirrored.
    FlipAdst = 2,
    /// Identity (scaling-only) kernel.
    Identity = 3,
}

impl TxType1D {
    /// Every variant, ordered by discriminant.
    const ALL: [Self; 4] = [Self::Dct, Self::Adst, Self::FlipAdst, Self::Identity];

    /// Number of 1-D kernel kinds.
    #[must_use]
    pub const fn count() -> usize {
        Self::ALL.len()
    }

    /// Converts a raw discriminant into a kernel kind.
    ///
    /// Returns `None` when `val` is not one of `0..=3`.
    #[must_use]
    pub const fn from_u8(val: u8) -> Option<Self> {
        let idx = val as usize;
        if idx < Self::ALL.len() {
            Some(Self::ALL[idx])
        } else {
            None
        }
    }
}
111
112#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
114pub enum TxType {
115 #[default]
117 DctDct = 0,
118 AdstDct = 1,
120 DctAdst = 2,
122 AdstAdst = 3,
124 FlipAdstDct = 4,
126 DctFlipAdst = 5,
128 FlipAdstAdst = 6,
130 AdstFlipAdst = 7,
132 FlipAdstFlipAdst = 8,
134 IdtxDct = 9,
136 DctIdtx = 10,
138 IdtxAdst = 11,
140 AdstIdtx = 12,
142 IdtxFlipAdst = 13,
144 FlipAdstIdtx = 14,
146 IdtxIdtx = 15,
148}
149
150impl TxType {
151 #[must_use]
153 pub const fn row_type(self) -> TxType1D {
154 match self {
155 Self::DctDct | Self::DctAdst | Self::DctFlipAdst | Self::DctIdtx => TxType1D::Dct,
156 Self::AdstDct | Self::AdstAdst | Self::AdstFlipAdst | Self::AdstIdtx => TxType1D::Adst,
157 Self::FlipAdstDct
158 | Self::FlipAdstAdst
159 | Self::FlipAdstFlipAdst
160 | Self::FlipAdstIdtx => TxType1D::FlipAdst,
161 Self::IdtxDct | Self::IdtxAdst | Self::IdtxFlipAdst | Self::IdtxIdtx => {
162 TxType1D::Identity
163 }
164 }
165 }
166
167 #[must_use]
169 pub const fn col_type(self) -> TxType1D {
170 match self {
171 Self::DctDct | Self::AdstDct | Self::FlipAdstDct | Self::IdtxDct => TxType1D::Dct,
172 Self::DctAdst | Self::AdstAdst | Self::FlipAdstAdst | Self::IdtxAdst => TxType1D::Adst,
173 Self::DctFlipAdst
174 | Self::AdstFlipAdst
175 | Self::FlipAdstFlipAdst
176 | Self::IdtxFlipAdst => TxType1D::FlipAdst,
177 Self::DctIdtx | Self::AdstIdtx | Self::FlipAdstIdtx | Self::IdtxIdtx => {
178 TxType1D::Identity
179 }
180 }
181 }
182
183 #[must_use]
185 pub const fn from_u8(val: u8) -> Option<Self> {
186 match val {
187 0 => Some(Self::DctDct),
188 1 => Some(Self::AdstDct),
189 2 => Some(Self::DctAdst),
190 3 => Some(Self::AdstAdst),
191 4 => Some(Self::FlipAdstDct),
192 5 => Some(Self::DctFlipAdst),
193 6 => Some(Self::FlipAdstAdst),
194 7 => Some(Self::AdstFlipAdst),
195 8 => Some(Self::FlipAdstFlipAdst),
196 9 => Some(Self::IdtxDct),
197 10 => Some(Self::DctIdtx),
198 11 => Some(Self::IdtxAdst),
199 12 => Some(Self::AdstIdtx),
200 13 => Some(Self::IdtxFlipAdst),
201 14 => Some(Self::FlipAdstIdtx),
202 15 => Some(Self::IdtxIdtx),
203 _ => None,
204 }
205 }
206
207 #[must_use]
209 pub const fn is_valid_for_size(self, tx_size: TxSize) -> bool {
210 let has_identity = matches!(
212 self,
213 Self::IdtxDct
214 | Self::DctIdtx
215 | Self::IdtxAdst
216 | Self::AdstIdtx
217 | Self::IdtxFlipAdst
218 | Self::FlipAdstIdtx
219 | Self::IdtxIdtx
220 );
221
222 if has_identity {
223 !matches!(
225 tx_size,
226 TxSize::Tx64x64
227 | TxSize::Tx32x64
228 | TxSize::Tx64x32
229 | TxSize::Tx16x64
230 | TxSize::Tx64x16
231 )
232 } else {
233 true
234 }
235 }
236
237 #[must_use]
239 pub const fn tx_class(self) -> TxClass {
240 match self {
241 Self::DctDct
242 | Self::AdstDct
243 | Self::DctAdst
244 | Self::AdstAdst
245 | Self::FlipAdstDct
246 | Self::DctFlipAdst
247 | Self::FlipAdstAdst
248 | Self::AdstFlipAdst
249 | Self::FlipAdstFlipAdst => TxClass::Class2D,
250 Self::IdtxDct | Self::IdtxAdst | Self::IdtxFlipAdst => TxClass::ClassVert,
251 Self::DctIdtx | Self::AdstIdtx | Self::FlipAdstIdtx | Self::IdtxIdtx => {
252 TxClass::ClassHoriz
253 }
254 }
255 }
256}
257
/// Directional class of a transform type, as reported by `TxType::tx_class`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub enum TxClass {
    /// Two-dimensional class (the default).
    #[default]
    Class2D = 0,
    /// Horizontal class.
    ClassHoriz = 1,
    /// Vertical class.
    ClassVert = 2,
}

impl TxClass {
    /// Every variant, ordered by discriminant.
    const ALL: [Self; 3] = [Self::Class2D, Self::ClassHoriz, Self::ClassVert];

    /// Number of transform classes.
    #[must_use]
    pub const fn count() -> usize {
        Self::ALL.len()
    }

    /// Converts a raw discriminant into a class.
    ///
    /// Returns `None` when `val` is not one of `0..=2`.
    #[must_use]
    pub const fn from_u8(val: u8) -> Option<Self> {
        let idx = val as usize;
        if idx < Self::ALL.len() {
            Some(Self::ALL[idx])
        } else {
            None
        }
    }
}
288
289#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
291pub enum TxSize {
292 #[default]
294 Tx4x4 = 0,
295 Tx8x8 = 1,
297 Tx16x16 = 2,
299 Tx32x32 = 3,
301 Tx64x64 = 4,
303 Tx4x8 = 5,
305 Tx8x4 = 6,
307 Tx8x16 = 7,
309 Tx16x8 = 8,
311 Tx16x32 = 9,
313 Tx32x16 = 10,
315 Tx32x64 = 11,
317 Tx64x32 = 12,
319 Tx4x16 = 13,
321 Tx16x4 = 14,
323 Tx8x32 = 15,
325 Tx32x8 = 16,
327 Tx16x64 = 17,
329 Tx64x16 = 18,
331}
332
333impl TxSize {
334 #[must_use]
336 pub const fn width(self) -> u32 {
337 match self {
338 Self::Tx4x4 | Self::Tx4x8 | Self::Tx4x16 => 4,
339 Self::Tx8x8 | Self::Tx8x4 | Self::Tx8x16 | Self::Tx8x32 => 8,
340 Self::Tx16x16 | Self::Tx16x8 | Self::Tx16x32 | Self::Tx16x4 | Self::Tx16x64 => 16,
341 Self::Tx32x32 | Self::Tx32x16 | Self::Tx32x64 | Self::Tx32x8 => 32,
342 Self::Tx64x64 | Self::Tx64x32 | Self::Tx64x16 => 64,
343 }
344 }
345
346 #[must_use]
348 pub const fn height(self) -> u32 {
349 match self {
350 Self::Tx4x4 | Self::Tx8x4 | Self::Tx16x4 => 4,
351 Self::Tx8x8 | Self::Tx4x8 | Self::Tx16x8 | Self::Tx32x8 => 8,
352 Self::Tx16x16 | Self::Tx8x16 | Self::Tx32x16 | Self::Tx4x16 | Self::Tx64x16 => 16,
353 Self::Tx32x32 | Self::Tx16x32 | Self::Tx64x32 | Self::Tx8x32 => 32,
354 Self::Tx64x64 | Self::Tx32x64 | Self::Tx16x64 => 64,
355 }
356 }
357
358 #[must_use]
360 pub const fn width_log2(self) -> u8 {
361 match self.width() {
362 4 => 2,
363 8 => 3,
364 16 => 4,
365 32 => 5,
366 64 => 6,
367 _ => 0,
368 }
369 }
370
371 #[must_use]
373 pub const fn height_log2(self) -> u8 {
374 match self.height() {
375 4 => 2,
376 8 => 3,
377 16 => 4,
378 32 => 5,
379 64 => 6,
380 _ => 0,
381 }
382 }
383
384 #[must_use]
386 pub const fn area(self) -> u32 {
387 self.width() * self.height()
388 }
389
390 #[must_use]
392 pub const fn is_square(self) -> bool {
393 self.width() == self.height()
394 }
395
396 #[must_use]
398 pub const fn sqr_size(self) -> TxSizeSqr {
399 match const_min_u32(self.width(), self.height()) {
400 4 => TxSizeSqr::Tx4x4,
401 8 => TxSizeSqr::Tx8x8,
402 16 => TxSizeSqr::Tx16x16,
403 32 => TxSizeSqr::Tx32x32,
404 64 => TxSizeSqr::Tx64x64,
405 _ => TxSizeSqr::Tx4x4,
406 }
407 }
408
409 #[must_use]
411 pub const fn from_u8(val: u8) -> Option<Self> {
412 match val {
413 0 => Some(Self::Tx4x4),
414 1 => Some(Self::Tx8x8),
415 2 => Some(Self::Tx16x16),
416 3 => Some(Self::Tx32x32),
417 4 => Some(Self::Tx64x64),
418 5 => Some(Self::Tx4x8),
419 6 => Some(Self::Tx8x4),
420 7 => Some(Self::Tx8x16),
421 8 => Some(Self::Tx16x8),
422 9 => Some(Self::Tx16x32),
423 10 => Some(Self::Tx32x16),
424 11 => Some(Self::Tx32x64),
425 12 => Some(Self::Tx64x32),
426 13 => Some(Self::Tx4x16),
427 14 => Some(Self::Tx16x4),
428 15 => Some(Self::Tx8x32),
429 16 => Some(Self::Tx32x8),
430 17 => Some(Self::Tx16x64),
431 18 => Some(Self::Tx64x16),
432 _ => None,
433 }
434 }
435
436 #[must_use]
438 pub const fn from_dimensions(width: u32, height: u32) -> Option<Self> {
439 match (width, height) {
440 (4, 4) => Some(Self::Tx4x4),
441 (8, 8) => Some(Self::Tx8x8),
442 (16, 16) => Some(Self::Tx16x16),
443 (32, 32) => Some(Self::Tx32x32),
444 (64, 64) => Some(Self::Tx64x64),
445 (4, 8) => Some(Self::Tx4x8),
446 (8, 4) => Some(Self::Tx8x4),
447 (8, 16) => Some(Self::Tx8x16),
448 (16, 8) => Some(Self::Tx16x8),
449 (16, 32) => Some(Self::Tx16x32),
450 (32, 16) => Some(Self::Tx32x16),
451 (32, 64) => Some(Self::Tx32x64),
452 (64, 32) => Some(Self::Tx64x32),
453 (4, 16) => Some(Self::Tx4x16),
454 (16, 4) => Some(Self::Tx16x4),
455 (8, 32) => Some(Self::Tx8x32),
456 (32, 8) => Some(Self::Tx32x8),
457 (16, 64) => Some(Self::Tx16x64),
458 (64, 16) => Some(Self::Tx64x16),
459 _ => None,
460 }
461 }
462
463 #[must_use]
465 #[allow(clippy::cast_possible_truncation)]
466 pub const fn max_eob(self) -> u16 {
467 self.area() as u16
468 }
469}
470
/// Square transform sizes, ordered from 4x4 up to 64x64.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub enum TxSizeSqr {
    #[default]
    Tx4x4 = 0,
    Tx8x8 = 1,
    Tx16x16 = 2,
    Tx32x32 = 3,
    Tx64x64 = 4,
}

impl TxSizeSqr {
    /// Side length (4, 8, 16, 32 or 64).
    #[must_use]
    pub const fn size(self) -> u32 {
        1u32 << self.log2()
    }

    /// Base-2 logarithm of the side length (2 through 6).
    #[must_use]
    pub const fn log2(self) -> u8 {
        // Discriminants run 0..=4 and the smallest side is 4 = 2^2.
        self as u8 + 2
    }
}
512
513#[derive(Clone, Debug, Default)]
519pub struct TransformContext {
520 pub tx_size: TxSize,
522 pub tx_type: TxType,
524 pub plane: u8,
526 pub row: u32,
528 pub col: u32,
530 pub skip: bool,
532 pub eob: u16,
534 pub bit_depth: u8,
536 pub lossless: bool,
538}
539
540impl TransformContext {
541 #[must_use]
543 pub const fn new(tx_size: TxSize, tx_type: TxType, plane: u8) -> Self {
544 Self {
545 tx_size,
546 tx_type,
547 plane,
548 row: 0,
549 col: 0,
550 skip: false,
551 eob: 0,
552 bit_depth: 8,
553 lossless: false,
554 }
555 }
556
557 pub fn set_position(&mut self, row: u32, col: u32) {
559 self.row = row;
560 self.col = col;
561 }
562
563 #[must_use]
565 pub const fn tx_class(&self) -> TxClass {
566 self.tx_type.tx_class()
567 }
568
569 #[must_use]
571 pub const fn stride(&self) -> u32 {
572 self.tx_size.width()
573 }
574
575 #[must_use]
577 pub const fn num_coeffs(&self) -> u32 {
578 self.tx_size.area()
579 }
580
581 #[must_use]
583 pub const fn is_luma(&self) -> bool {
584 self.plane == 0
585 }
586
587 #[must_use]
589 pub const fn is_chroma(&self) -> bool {
590 self.plane > 0
591 }
592
593 #[must_use]
596 pub const fn effective_size(&self) -> (u32, u32) {
597 let w = self.tx_size.width();
598 let h = self.tx_size.height();
599 (const_min_u32(w, 32), const_min_u32(h, 32))
600 }
601}
602
603#[must_use]
609#[allow(clippy::cast_possible_truncation)]
610fn cos_value(n: usize, k: usize, size: usize) -> i32 {
611 let angle = PI * (2.0 * k as f64 + 1.0) * n as f64 / (2.0 * size as f64);
612 (angle.cos() * f64::from(1 << COS_BIT)).round() as i32
613}
614
615#[must_use]
617#[allow(clippy::cast_possible_truncation)]
618fn sin_value(n: usize, k: usize, size: usize) -> i32 {
619 let angle = PI * (2.0 * n as f64 + 1.0) * (2.0 * k as f64 + 1.0) / (4.0 * size as f64);
620 (angle.sin() * f64::from(1 << COS_BIT)).round() as i32
621}
622
623#[must_use]
625fn round_shift_sat(value: i64, shift: u8) -> i32 {
626 let shifted = if shift == 0 {
627 value
628 } else {
629 let round = 1i64 << (shift - 1);
630 (value + round) >> shift
631 };
632 shifted.clamp(i64::from(TX_COEFF_MIN), i64::from(TX_COEFF_MAX)) as i32
633}
634
635#[allow(clippy::cast_possible_truncation)]
641pub fn dct4(input: &[i32; 4], output: &mut [i32; 4], cos_bit: u8) {
642 let s0 = input[0] + input[3];
644 let s1 = input[1] + input[2];
645 let s2 = input[1] - input[2];
646 let s3 = input[0] - input[3];
647
648 let cos_k = [
650 cos_value(0, 0, 8), cos_value(1, 0, 8), cos_value(2, 0, 8), cos_value(3, 0, 8), ];
655
656 let t0 = i64::from(s0 + s1) * i64::from(cos_k[0]);
657 let t1 = i64::from(s0 - s1) * i64::from(cos_k[2]);
658 let t2 = i64::from(s2) * i64::from(cos_k[3]) + i64::from(s3) * i64::from(cos_k[1]);
659 let t3 = i64::from(s3) * i64::from(cos_k[3]) - i64::from(s2) * i64::from(cos_k[1]);
660
661 output[0] = round_shift_sat(t0, cos_bit);
662 output[1] = round_shift_sat(t2, cos_bit);
663 output[2] = round_shift_sat(t1, cos_bit);
664 output[3] = round_shift_sat(t3, cos_bit);
665}
666
667#[allow(clippy::cast_possible_truncation, clippy::similar_names)]
669pub fn dct8(input: &[i32; 8], output: &mut [i32; 8], cos_bit: u8) {
670 let s0 = input[0] + input[7];
672 let s1 = input[1] + input[6];
673 let s2 = input[2] + input[5];
674 let s3 = input[3] + input[4];
675 let s4 = input[3] - input[4];
676 let s5 = input[2] - input[5];
677 let s6 = input[1] - input[6];
678 let s7 = input[0] - input[7];
679
680 let even_in = [s0, s1, s2, s3];
682 let mut even_out = [0i32; 4];
683
684 let e0 = even_in[0] + even_in[3];
686 let e1 = even_in[1] + even_in[2];
687 let e2 = even_in[1] - even_in[2];
688 let e3 = even_in[0] - even_in[3];
689
690 even_out[0] = round_shift_sat(i64::from(e0 + e1) * i64::from(cos_value(0, 0, 16)), cos_bit);
691 even_out[2] = round_shift_sat(i64::from(e0 - e1) * i64::from(cos_value(4, 0, 16)), cos_bit);
692 even_out[1] = round_shift_sat(
693 i64::from(e2) * i64::from(cos_value(6, 0, 16))
694 + i64::from(e3) * i64::from(cos_value(2, 0, 16)),
695 cos_bit,
696 );
697 even_out[3] = round_shift_sat(
698 i64::from(e3) * i64::from(cos_value(6, 0, 16))
699 - i64::from(e2) * i64::from(cos_value(2, 0, 16)),
700 cos_bit,
701 );
702
703 let cos1 = cos_value(1, 0, 16);
705 let cos3 = cos_value(3, 0, 16);
706 let cos5 = cos_value(5, 0, 16);
707 let cos7 = cos_value(7, 0, 16);
708
709 let o0 = round_shift_sat(
710 i64::from(s4) * i64::from(cos7) + i64::from(s7) * i64::from(cos1),
711 cos_bit,
712 );
713 let o1 = round_shift_sat(
714 i64::from(s5) * i64::from(cos5) + i64::from(s6) * i64::from(cos3),
715 cos_bit,
716 );
717 let o2 = round_shift_sat(
718 i64::from(s6) * i64::from(cos5) - i64::from(s5) * i64::from(cos3),
719 cos_bit,
720 );
721 let o3 = round_shift_sat(
722 i64::from(s7) * i64::from(cos7) - i64::from(s4) * i64::from(cos1),
723 cos_bit,
724 );
725
726 output[0] = even_out[0];
728 output[1] = o0;
729 output[2] = even_out[1];
730 output[3] = o1;
731 output[4] = even_out[2];
732 output[5] = o2;
733 output[6] = even_out[3];
734 output[7] = o3;
735}
736
737#[allow(clippy::cast_possible_truncation)]
739pub fn dct16(input: &[i32; 16], output: &mut [i32; 16], cos_bit: u8) {
740 let mut even = [0i32; 8];
742 let mut odd = [0i32; 8];
743
744 for i in 0..8 {
746 even[i] = input[i] + input[15 - i];
747 odd[i] = input[i] - input[15 - i];
748 }
749
750 let mut even_out = [0i32; 8];
752 dct8(&even, &mut even_out, cos_bit);
753
754 for i in 0..8 {
756 let cos_idx = 2 * i + 1;
757 let cos_val = cos_value(cos_idx, 0, 32);
758 output[2 * i + 1] = round_shift_sat(i64::from(odd[i]) * i64::from(cos_val), cos_bit);
759 }
760
761 for i in 0..8 {
763 output[2 * i] = even_out[i];
764 }
765}
766
767pub fn dct32(input: &[i32; 32], output: &mut [i32; 32], cos_bit: u8) {
769 let mut even = [0i32; 16];
771 let mut odd = [0i32; 16];
772
773 for i in 0..16 {
774 even[i] = input[i] + input[31 - i];
775 odd[i] = input[i] - input[31 - i];
776 }
777
778 let mut even_out = [0i32; 16];
779 dct16(&even, &mut even_out, cos_bit);
780
781 for i in 0..16 {
782 let cos_idx = 2 * i + 1;
783 let cos_val = cos_value(cos_idx, 0, 64);
784 output[2 * i + 1] = round_shift_sat(i64::from(odd[i]) * i64::from(cos_val), cos_bit);
785 }
786
787 for i in 0..16 {
788 output[2 * i] = even_out[i];
789 }
790}
791
792pub fn dct64(input: &[i32; 64], output: &mut [i32; 64], cos_bit: u8) {
794 let mut even = [0i32; 32];
796 let mut odd = [0i32; 32];
797
798 for i in 0..32 {
799 even[i] = input[i] + input[63 - i];
800 odd[i] = input[i] - input[63 - i];
801 }
802
803 let mut even_out = [0i32; 32];
804 dct32(&even, &mut even_out, cos_bit);
805
806 for i in 0..32 {
807 let cos_idx = 2 * i + 1;
808 let cos_val = cos_value(cos_idx, 0, 128);
809 output[2 * i + 1] = round_shift_sat(i64::from(odd[i]) * i64::from(cos_val), cos_bit);
810 }
811
812 for i in 0..32 {
813 output[2 * i] = even_out[i];
814 }
815}
816
817#[allow(clippy::cast_possible_truncation)]
823pub fn adst4(input: &[i32; 4], output: &mut [i32; 4], cos_bit: u8) {
824 let sin_pi_9 = sin_value(0, 0, 9);
826 let sin_2pi_9 = sin_value(1, 0, 9);
827 let sin_3pi_9 = sin_value(2, 0, 9);
828 let sin_4pi_9 = sin_value(3, 0, 9);
829
830 let s0 = i64::from(input[0]) * i64::from(sin_pi_9);
831 let s1 = i64::from(input[0]) * i64::from(sin_2pi_9);
832 let s2 = i64::from(input[1]) * i64::from(sin_3pi_9);
833 let s3 = i64::from(input[2]) * i64::from(sin_4pi_9);
834 let s4 = i64::from(input[2]) * i64::from(sin_pi_9);
835 let s5 = i64::from(input[3]) * i64::from(sin_2pi_9);
836 let s6 = i64::from(input[3]) * i64::from(sin_4pi_9);
837
838 let t0 = s0 + s3 + s5;
839 let t1 = s1 + s2 - s6;
840 let t2 = s1 - s2 + s6;
841 let t3 = s0 - s3 + s4;
842
843 output[0] = round_shift_sat(t0, cos_bit);
844 output[1] = round_shift_sat(t1, cos_bit);
845 output[2] = round_shift_sat(t2, cos_bit);
846 output[3] = round_shift_sat(t3, cos_bit);
847}
848
849pub fn adst8(input: &[i32; 8], output: &mut [i32; 8], cos_bit: u8) {
851 for (i, out) in output.iter_mut().enumerate() {
853 let mut sum = 0i64;
854 for (j, &inp) in input.iter().enumerate() {
855 let sin_val = sin_value(i, j, 8);
856 sum += i64::from(inp) * i64::from(sin_val);
857 }
858 *out = round_shift_sat(sum, cos_bit);
859 }
860}
861
862pub fn adst16(input: &[i32; 16], output: &mut [i32; 16], cos_bit: u8) {
864 for (i, out) in output.iter_mut().enumerate() {
866 let mut sum = 0i64;
867 for (j, &inp) in input.iter().enumerate() {
868 let sin_val = sin_value(i, j, 16);
869 sum += i64::from(inp) * i64::from(sin_val);
870 }
871 *out = round_shift_sat(sum, cos_bit);
872 }
873}
874
/// 4-point identity kernel: writes every input multiplied by 2.
pub fn identity4(input: &[i32; 4], output: &mut [i32; 4]) {
    for (dst, &src) in output.iter_mut().zip(input.iter()) {
        *dst = src * 2;
    }
}
886
/// 8-point identity kernel: writes every input multiplied by 2.
pub fn identity8(input: &[i32; 8], output: &mut [i32; 8]) {
    for (dst, &src) in output.iter_mut().zip(input.iter()) {
        *dst = src * 2;
    }
}
893
/// 16-point identity kernel: writes every input multiplied by 2.
pub fn identity16(input: &[i32; 16], output: &mut [i32; 16]) {
    for (dst, &src) in output.iter_mut().zip(input.iter()) {
        *dst = src * 2;
    }
}
900
/// 32-point identity kernel: writes every input multiplied by 4.
pub fn identity32(input: &[i32; 32], output: &mut [i32; 32]) {
    for (dst, &src) in output.iter_mut().zip(input.iter()) {
        *dst = src * 4;
    }
}
907
908pub fn idct4(input: &[i32; 4], output: &mut [i32; 4], cos_bit: u8) {
914 let cos0 = cos_value(0, 0, 8);
917 let cos1 = cos_value(1, 0, 8);
918 let cos2 = cos_value(2, 0, 8);
919 let cos3 = cos_value(3, 0, 8);
920
921 let t0 = i64::from(input[0]) * i64::from(cos0);
923 let t1 = i64::from(input[2]) * i64::from(cos2);
924 let t2 = i64::from(input[1]) * i64::from(cos1) + i64::from(input[3]) * i64::from(cos3);
925 let t3 = i64::from(input[1]) * i64::from(cos3) - i64::from(input[3]) * i64::from(cos1);
926
927 let s0 = round_shift_sat(t0 + t1, cos_bit);
928 let s1 = round_shift_sat(t0 - t1, cos_bit);
929 let s2 = round_shift_sat(t2, cos_bit);
930 let s3 = round_shift_sat(t3, cos_bit);
931
932 output[0] = s0 + s2;
934 output[1] = s1 + s3;
935 output[2] = s1 - s3;
936 output[3] = s0 - s2;
937}
938
939pub fn idct8(input: &[i32; 8], output: &mut [i32; 8], cos_bit: u8) {
941 for (i, out) in output.iter_mut().enumerate() {
943 let mut sum = 0i64;
944 for (j, &inp) in input.iter().enumerate() {
945 let cos_val = cos_value(j, i, 8);
946 sum += i64::from(inp) * i64::from(cos_val);
947 }
948 *out = round_shift_sat(sum, cos_bit);
949 }
950}
951
952pub fn idct16(input: &[i32; 16], output: &mut [i32; 16], cos_bit: u8) {
954 for (i, out) in output.iter_mut().enumerate() {
955 let mut sum = 0i64;
956 for (j, &inp) in input.iter().enumerate() {
957 let cos_val = cos_value(j, i, 16);
958 sum += i64::from(inp) * i64::from(cos_val);
959 }
960 *out = round_shift_sat(sum, cos_bit);
961 }
962}
963
964pub fn idct32(input: &[i32; 32], output: &mut [i32; 32], cos_bit: u8) {
966 for (i, out) in output.iter_mut().enumerate() {
967 let mut sum = 0i64;
968 for (j, &inp) in input.iter().enumerate() {
969 let cos_val = cos_value(j, i, 32);
970 sum += i64::from(inp) * i64::from(cos_val);
971 }
972 *out = round_shift_sat(sum, cos_bit);
973 }
974}
975
976pub fn idct64(input: &[i32; 64], output: &mut [i32; 64], cos_bit: u8) {
978 for (i, out) in output.iter_mut().enumerate() {
979 let mut sum = 0i64;
980 for (j, &inp) in input.iter().enumerate() {
981 let cos_val = cos_value(j, i, 64);
982 sum += i64::from(inp) * i64::from(cos_val);
983 }
984 *out = round_shift_sat(sum, cos_bit);
985 }
986}
987
988pub fn iadst4(input: &[i32; 4], output: &mut [i32; 4], cos_bit: u8) {
990 adst4(input, output, cos_bit);
992}
993
994pub fn iadst8(input: &[i32; 8], output: &mut [i32; 8], cos_bit: u8) {
996 adst8(input, output, cos_bit);
997}
998
999pub fn iadst16(input: &[i32; 16], output: &mut [i32; 16], cos_bit: u8) {
1001 adst16(input, output, cos_bit);
1002}
1003
1004#[derive(Clone, Debug)]
1010pub struct Transform2D {
1011 buffer: Vec<i32>,
1013 tx_size: TxSize,
1015 tx_type: TxType,
1017}
1018
1019impl Transform2D {
1020 #[must_use]
1022 pub fn new(tx_size: TxSize, tx_type: TxType) -> Self {
1023 let area = tx_size.area() as usize;
1024 Self {
1025 buffer: vec![0; area],
1026 tx_size,
1027 tx_type,
1028 }
1029 }
1030
1031 pub fn inverse(&mut self, input: &[i32], output: &mut [i32]) {
1033 let width = self.tx_size.width() as usize;
1034 let height = self.tx_size.height() as usize;
1035 let _cos_bit = COS_BIT;
1036
1037 for row in 0..height {
1039 let row_start = row * width;
1040 self.apply_row_inverse(&input[row_start..row_start + width], row);
1041 }
1042
1043 for col in 0..width {
1045 self.apply_col_inverse(col, &mut output[col..], width);
1046 }
1047 }
1048
1049 fn apply_row_inverse(&mut self, input: &[i32], row: usize) {
1051 let width = self.tx_size.width() as usize;
1052 let row_type = self.tx_type.row_type();
1053 let cos_bit = COS_BIT;
1054
1055 let mut row_out = vec![0i32; width];
1057
1058 match (row_type, width) {
1059 (TxType1D::Dct, 4) => {
1060 let mut inp = [0i32; 4];
1061 let mut out = [0i32; 4];
1062 inp.copy_from_slice(input);
1063 idct4(&inp, &mut out, cos_bit);
1064 row_out.copy_from_slice(&out);
1065 }
1066 (TxType1D::Dct, 8) => {
1067 let mut inp = [0i32; 8];
1068 let mut out = [0i32; 8];
1069 inp.copy_from_slice(input);
1070 idct8(&inp, &mut out, cos_bit);
1071 row_out.copy_from_slice(&out);
1072 }
1073 (TxType1D::Adst, 4) => {
1074 let mut inp = [0i32; 4];
1075 let mut out = [0i32; 4];
1076 inp.copy_from_slice(input);
1077 iadst4(&inp, &mut out, cos_bit);
1078 row_out.copy_from_slice(&out);
1079 }
1080 (TxType1D::Adst, 8) => {
1081 let mut inp = [0i32; 8];
1082 let mut out = [0i32; 8];
1083 inp.copy_from_slice(input);
1084 iadst8(&inp, &mut out, cos_bit);
1085 row_out.copy_from_slice(&out);
1086 }
1087 (TxType1D::Identity, 4) => {
1088 let mut inp = [0i32; 4];
1089 let mut out = [0i32; 4];
1090 inp.copy_from_slice(input);
1091 identity4(&inp, &mut out);
1092 row_out.copy_from_slice(&out);
1093 }
1094 (TxType1D::Identity, 8) => {
1095 let mut inp = [0i32; 8];
1096 let mut out = [0i32; 8];
1097 inp.copy_from_slice(input);
1098 identity8(&inp, &mut out);
1099 row_out.copy_from_slice(&out);
1100 }
1101 (TxType1D::FlipAdst, n) => {
1102 let mut temp = vec![0i32; n];
1104 match n {
1105 4 => {
1106 let mut inp = [0i32; 4];
1107 let mut out = [0i32; 4];
1108 inp.copy_from_slice(input);
1109 iadst4(&inp, &mut out, cos_bit);
1110 temp.copy_from_slice(&out);
1111 }
1112 8 => {
1113 let mut inp = [0i32; 8];
1114 let mut out = [0i32; 8];
1115 inp.copy_from_slice(input);
1116 iadst8(&inp, &mut out, cos_bit);
1117 temp.copy_from_slice(&out);
1118 }
1119 _ => temp.copy_from_slice(input),
1120 }
1121 for i in 0..n {
1122 row_out[i] = temp[n - 1 - i];
1123 }
1124 }
1125 _ => {
1126 row_out[..width].copy_from_slice(&input[..width]);
1128 }
1129 }
1130
1131 let row_start = row * width;
1133 self.buffer[row_start..row_start + width].copy_from_slice(&row_out);
1134 }
1135
1136 fn apply_col_inverse(&self, col: usize, output: &mut [i32], stride: usize) {
1138 let width = self.tx_size.width() as usize;
1139 let height = self.tx_size.height() as usize;
1140 let col_type = self.tx_type.col_type();
1141 let cos_bit = COS_BIT;
1142
1143 let mut col_in = vec![0i32; height];
1145 for row in 0..height {
1146 col_in[row] = self.buffer[row * width + col];
1147 }
1148
1149 let mut col_out = vec![0i32; height];
1150
1151 match (col_type, height) {
1152 (TxType1D::Dct, 4) => {
1153 let mut inp = [0i32; 4];
1154 let mut out = [0i32; 4];
1155 inp.copy_from_slice(&col_in);
1156 idct4(&inp, &mut out, cos_bit);
1157 col_out.copy_from_slice(&out);
1158 }
1159 (TxType1D::Dct, 8) => {
1160 let mut inp = [0i32; 8];
1161 let mut out = [0i32; 8];
1162 inp.copy_from_slice(&col_in);
1163 idct8(&inp, &mut out, cos_bit);
1164 col_out.copy_from_slice(&out);
1165 }
1166 (TxType1D::Adst, 4) => {
1167 let mut inp = [0i32; 4];
1168 let mut out = [0i32; 4];
1169 inp.copy_from_slice(&col_in);
1170 iadst4(&inp, &mut out, cos_bit);
1171 col_out.copy_from_slice(&out);
1172 }
1173 (TxType1D::Adst, 8) => {
1174 let mut inp = [0i32; 8];
1175 let mut out = [0i32; 8];
1176 inp.copy_from_slice(&col_in);
1177 iadst8(&inp, &mut out, cos_bit);
1178 col_out.copy_from_slice(&out);
1179 }
1180 (TxType1D::Identity, 4) => {
1181 let mut inp = [0i32; 4];
1182 let mut out = [0i32; 4];
1183 inp.copy_from_slice(&col_in);
1184 identity4(&inp, &mut out);
1185 col_out.copy_from_slice(&out);
1186 }
1187 (TxType1D::Identity, 8) => {
1188 let mut inp = [0i32; 8];
1189 let mut out = [0i32; 8];
1190 inp.copy_from_slice(&col_in);
1191 identity8(&inp, &mut out);
1192 col_out.copy_from_slice(&out);
1193 }
1194 (TxType1D::FlipAdst, n) => {
1195 let mut temp = vec![0i32; n];
1196 match n {
1197 4 => {
1198 let mut inp = [0i32; 4];
1199 let mut out = [0i32; 4];
1200 inp.copy_from_slice(&col_in);
1201 iadst4(&inp, &mut out, cos_bit);
1202 temp.copy_from_slice(&out);
1203 }
1204 8 => {
1205 let mut inp = [0i32; 8];
1206 let mut out = [0i32; 8];
1207 inp.copy_from_slice(&col_in);
1208 iadst8(&inp, &mut out, cos_bit);
1209 temp.copy_from_slice(&out);
1210 }
1211 _ => temp.copy_from_slice(&col_in),
1212 }
1213 for i in 0..n {
1214 col_out[i] = temp[n - 1 - i];
1215 }
1216 }
1217 _ => {
1218 col_out.copy_from_slice(&col_in);
1219 }
1220 }
1221
1222 for row in 0..height {
1224 output[row * stride] = col_out[row];
1225 }
1226 }
1227}
1228
/// Mirrors a row-major `width` x `height` block left-to-right in place by
/// reversing each of the first `height` rows.
///
/// Panics if `coeffs` is shorter than `width * height` and `height > 0`.
pub fn flip_horizontal(coeffs: &mut [i32], width: usize, height: usize) {
    let mut start = 0;
    for _ in 0..height {
        let end = start + width;
        coeffs[start..end].reverse();
        start = end;
    }
}
1240
/// Mirrors a row-major `width` x `height` block top-to-bottom in place by
/// swapping row `r` with row `height - 1 - r`, one column at a time.
pub fn flip_vertical(coeffs: &mut [i32], width: usize, height: usize) {
    for col in 0..width {
        let row_pairs = (0..height / 2).map(|upper| (upper, height - 1 - upper));
        for (upper, lower) in row_pairs {
            coeffs.swap(upper * width + col, lower * width + col);
        }
    }
}
1251
/// 4x4 Walsh-Hadamard-style transform on a row-major block.
///
/// The input is copied to `output`, a 4-point add/subtract butterfly is run
/// over each row, and the same butterfly is then run over each column with
/// every result arithmetically shifted right by 2.
pub fn wht4x4(input: &[i32; 16], output: &mut [i32; 16]) {
    *output = *input;

    // Row pass.
    for quad in output.chunks_exact_mut(4) {
        let low_sum = quad[0] + quad[1];
        let high_sum = quad[2] + quad[3];
        let low_diff = quad[0] - quad[1];
        let high_diff = quad[2] - quad[3];

        quad[0] = low_sum + high_sum;
        quad[1] = low_diff + high_diff;
        quad[2] = low_sum - high_sum;
        quad[3] = low_diff - high_diff;
    }

    // Column pass with the final scaling.
    for col in 0..4 {
        let [r0, r1, r2, r3] = [col, col + 4, col + 8, col + 12];

        let low_sum = output[r0] + output[r1];
        let high_sum = output[r2] + output[r3];
        let low_diff = output[r0] - output[r1];
        let high_diff = output[r2] - output[r3];

        output[r0] = (low_sum + high_sum) >> 2;
        output[r1] = (low_diff + high_diff) >> 2;
        output[r2] = (low_sum - high_sum) >> 2;
        output[r3] = (low_diff - high_diff) >> 2;
    }
}
1290
1291pub fn iwht4x4(input: &[i32; 16], output: &mut [i32; 16]) {
1293 wht4x4(input, output);
1295}
1296
1297#[must_use]
1304pub const fn get_reduced_tx_size(tx_size: TxSize) -> (u32, u32) {
1305 let width = tx_size.width();
1306 let height = tx_size.height();
1307 (const_min_u32(width, 32), const_min_u32(height, 32))
1308}
1309
1310#[must_use]
1312pub const fn needs_reduction(tx_size: TxSize) -> bool {
1313 tx_size.width() > 32 || tx_size.height() > 32
1314}
1315
1316#[must_use]
1318pub const fn get_max_nonzero_coeffs(tx_size: TxSize) -> u32 {
1319 let (w, h) = get_reduced_tx_size(tx_size);
1320 w * h
1321}
1322
/// Unit tests for the transform types, sizes and kernels.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tx_type_components() {
        let cases = [
            (TxType::DctDct, TxType1D::Dct, TxType1D::Dct),
            (TxType::AdstDct, TxType1D::Adst, TxType1D::Dct),
            (TxType::IdtxIdtx, TxType1D::Identity, TxType1D::Identity),
        ];
        for (tx_type, row, col) in cases.iter() {
            assert_eq!(tx_type.row_type(), *row);
            assert_eq!(tx_type.col_type(), *col);
        }
    }

    #[test]
    fn test_tx_size_dimensions() {
        let cases = [
            (TxSize::Tx4x4, 4, 4),
            (TxSize::Tx4x8, 4, 8),
            (TxSize::Tx64x64, 64, 64),
        ];
        for (size, width, height) in cases.iter() {
            assert_eq!(size.width(), *width);
            assert_eq!(size.height(), *height);
        }
    }

    #[test]
    fn test_tx_size_log2() {
        let squares = [
            TxSize::Tx4x4,
            TxSize::Tx8x8,
            TxSize::Tx16x16,
            TxSize::Tx32x32,
            TxSize::Tx64x64,
        ];
        // Square sizes are listed in ascending order, so log2 counts up from 2.
        for (size, expected) in squares.iter().zip(2u8..) {
            assert_eq!(size.width_log2(), expected);
        }
    }

    #[test]
    fn test_tx_size_area() {
        let cases = [(TxSize::Tx4x4, 16), (TxSize::Tx8x8, 64), (TxSize::Tx4x8, 32)];
        for (size, area) in cases.iter() {
            assert_eq!(size.area(), *area);
        }
    }

    #[test]
    fn test_tx_size_is_square() {
        for size in [TxSize::Tx4x4, TxSize::Tx8x8].iter() {
            assert!(size.is_square());
        }
        for size in [TxSize::Tx4x8, TxSize::Tx8x4].iter() {
            assert!(!size.is_square());
        }
    }

    #[test]
    fn test_tx_class() {
        let cases = [
            (TxType::DctDct, TxClass::Class2D),
            (TxType::IdtxDct, TxClass::ClassVert),
            (TxType::DctIdtx, TxClass::ClassHoriz),
        ];
        for (tx_type, class) in cases.iter() {
            assert_eq!(tx_type.tx_class(), *class);
        }
    }

    #[test]
    fn test_tx_type_from_u8() {
        let cases = [
            (0, Some(TxType::DctDct)),
            (15, Some(TxType::IdtxIdtx)),
            (16, None),
        ];
        for (raw, expected) in cases.iter() {
            assert_eq!(TxType::from_u8(*raw), *expected);
        }
    }

    #[test]
    fn test_tx_size_from_u8() {
        let cases = [
            (0, Some(TxSize::Tx4x4)),
            (18, Some(TxSize::Tx64x16)),
            (19, None),
        ];
        for (raw, expected) in cases.iter() {
            assert_eq!(TxSize::from_u8(*raw), *expected);
        }
    }

    #[test]
    fn test_tx_size_from_dimensions() {
        let cases = [
            ((4, 4), Some(TxSize::Tx4x4)),
            ((64, 64), Some(TxSize::Tx64x64)),
            ((4, 8), Some(TxSize::Tx4x8)),
            ((3, 3), None),
        ];
        for ((width, height), expected) in cases.iter() {
            assert_eq!(TxSize::from_dimensions(*width, *height), *expected);
        }
    }

    #[test]
    fn test_transform_context() {
        let ctx = TransformContext::new(TxSize::Tx8x8, TxType::DctDct, 0);
        assert_eq!((ctx.stride(), ctx.num_coeffs()), (8, 64));
        assert!(ctx.is_luma());
        assert!(!ctx.is_chroma());
    }

    #[test]
    fn test_dct4_identity() {
        // A flat input should put more magnitude in output 0 than in output 1.
        let flat = [1, 1, 1, 1];
        let mut coeffs = [0i32; 4];
        dct4(&flat, &mut coeffs, COS_BIT);
        let (dc, first_ac) = (coeffs[0].abs(), coeffs[1].abs());
        assert!(dc > first_ac);
    }

    #[test]
    fn test_idct4_dct4_roundtrip() {
        let input = [100, 50, -30, 80];
        let mut forward = [0i32; 4];
        let mut restored = [0i32; 4];

        dct4(&input, &mut forward, COS_BIT);
        idct4(&forward, &mut restored, COS_BIT);

        for (i, (&original, &recovered)) in input.iter().zip(restored.iter()).enumerate() {
            let diff = (original - recovered).abs();
            assert!(diff < 500, "Roundtrip error too large at {i}: {diff}");
        }
    }

    #[test]
    fn test_identity_transform() {
        let input = [1, 2, 3, 4];
        let mut output = [0i32; 4];
        identity4(&input, &mut output);

        for (&scaled, &original) in output.iter().zip(input.iter()) {
            assert_eq!(scaled, original * 2);
        }
    }

    #[test]
    fn test_wht4x4() {
        let ones = [1i32; 16];
        let mut transformed = [0i32; 16];
        wht4x4(&ones, &mut transformed);

        assert_ne!(transformed[0], 0);
    }

    #[test]
    fn test_reduced_tx_size() {
        let cases = [
            (TxSize::Tx4x4, (4, 4)),
            (TxSize::Tx64x64, (32, 32)),
            (TxSize::Tx64x32, (32, 32)),
        ];
        for (size, reduced) in cases.iter() {
            assert_eq!(get_reduced_tx_size(*size), *reduced);
        }
    }

    #[test]
    fn test_needs_reduction() {
        assert!(!needs_reduction(TxSize::Tx32x32));
        for size in [TxSize::Tx64x64, TxSize::Tx64x32].iter() {
            assert!(needs_reduction(*size));
        }
    }

    #[test]
    fn test_max_nonzero_coeffs() {
        let cases = [(TxSize::Tx4x4, 16), (TxSize::Tx64x64, 1024)];
        for (size, count) in cases.iter() {
            assert_eq!(get_max_nonzero_coeffs(*size), *count);
        }
    }

    #[test]
    fn test_tx_type_valid_for_size() {
        assert!(TxType::DctDct.is_valid_for_size(TxSize::Tx64x64));
        assert!(!TxType::IdtxIdtx.is_valid_for_size(TxSize::Tx64x64));
        assert!(TxType::IdtxIdtx.is_valid_for_size(TxSize::Tx32x32));
    }

    #[test]
    fn test_flip_horizontal() {
        let mut block = [1, 2, 3, 4, 5, 6, 7, 8];
        flip_horizontal(&mut block, 4, 2);
        assert_eq!(block, [4, 3, 2, 1, 8, 7, 6, 5]);
    }

    #[test]
    fn test_flip_vertical() {
        let mut block = [1, 2, 3, 4, 5, 6, 7, 8];
        flip_vertical(&mut block, 4, 2);
        assert_eq!(block, [5, 6, 7, 8, 1, 2, 3, 4]);
    }

    #[test]
    fn test_transform_2d_new() {
        let transform = Transform2D::new(TxSize::Tx8x8, TxType::DctDct);
        assert_eq!(transform.buffer.len(), 64);
    }

    #[test]
    fn test_tx_size_sqr() {
        assert_eq!(TxSizeSqr::Tx4x4.size(), 4);
        assert_eq!(TxSizeSqr::Tx8x8.size(), 8);
        assert_eq!(TxSizeSqr::Tx64x64.log2(), 6);
    }

    #[test]
    fn test_cos_value() {
        let at_zero = cos_value(0, 0, 8);
        assert!(at_zero > 0);

        let further_along = cos_value(4, 0, 8);
        assert!(further_along.abs() < at_zero);
    }

    #[test]
    fn test_round_shift_sat() {
        assert_eq!(round_shift_sat(100, 2), 25);
        assert_eq!(round_shift_sat(100, 1), 50);

        // Values beyond the coefficient range must saturate in both directions.
        let above_max = i64::from(TX_COEFF_MAX) * 4;
        assert_eq!(round_shift_sat(above_max, 1), TX_COEFF_MAX);
        let below_min = i64::from(TX_COEFF_MIN) * 4;
        assert_eq!(round_shift_sat(below_min, 1), TX_COEFF_MIN);
    }

    #[test]
    fn test_constants() {
        assert_eq!(TX_TYPES, 16);
        assert_eq!(TX_SIZES, 19);
        assert_eq!(MAX_TX_SIZE, 64);
        assert_eq!(MAX_TX_SQUARE, 4096);
    }
}