#![allow(dead_code)]
#![allow(clippy::similar_names)]
#![allow(clippy::many_single_char_names)]
#![allow(clippy::match_same_arms)]
#![allow(clippy::cast_precision_loss)]
#![allow(clippy::cast_possible_truncation)]
#![allow(clippy::no_effect_underscore_binding)]
#![allow(clippy::needless_range_loop)]
use std::f64::consts::PI;
pub const COS_BIT: u8 = 14;
pub const COS_ROUND: i32 = 1 << (COS_BIT - 1);
pub const TX_COEFF_MAX: i32 = (1 << 15) - 1;
pub const TX_COEFF_MIN: i32 = -(1 << 15);
pub const TX_TYPES: usize = 16;
pub const TX_SIZES: usize = 19;
const fn const_min_u32(a: u32, b: u32) -> u32 {
if a < b {
a
} else {
b
}
}
pub const TX_SIZES_SQ: usize = 5;
pub const MAX_TX_SIZE: usize = 64;
pub const MAX_TX_SQUARE: usize = MAX_TX_SIZE * MAX_TX_SIZE;
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub enum TxType1D {
#[default]
Dct = 0,
Adst = 1,
FlipAdst = 2,
Identity = 3,
}
impl TxType1D {
#[must_use]
pub const fn count() -> usize {
4
}
#[must_use]
pub const fn from_u8(val: u8) -> Option<Self> {
match val {
0 => Some(Self::Dct),
1 => Some(Self::Adst),
2 => Some(Self::FlipAdst),
3 => Some(Self::Identity),
_ => None,
}
}
}
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub enum TxType {
#[default]
DctDct = 0,
AdstDct = 1,
DctAdst = 2,
AdstAdst = 3,
FlipAdstDct = 4,
DctFlipAdst = 5,
FlipAdstAdst = 6,
AdstFlipAdst = 7,
FlipAdstFlipAdst = 8,
IdtxDct = 9,
DctIdtx = 10,
IdtxAdst = 11,
AdstIdtx = 12,
IdtxFlipAdst = 13,
FlipAdstIdtx = 14,
IdtxIdtx = 15,
}
impl TxType {
#[must_use]
pub const fn row_type(self) -> TxType1D {
match self {
Self::DctDct | Self::DctAdst | Self::DctFlipAdst | Self::DctIdtx => TxType1D::Dct,
Self::AdstDct | Self::AdstAdst | Self::AdstFlipAdst | Self::AdstIdtx => TxType1D::Adst,
Self::FlipAdstDct
| Self::FlipAdstAdst
| Self::FlipAdstFlipAdst
| Self::FlipAdstIdtx => TxType1D::FlipAdst,
Self::IdtxDct | Self::IdtxAdst | Self::IdtxFlipAdst | Self::IdtxIdtx => {
TxType1D::Identity
}
}
}
#[must_use]
pub const fn col_type(self) -> TxType1D {
match self {
Self::DctDct | Self::AdstDct | Self::FlipAdstDct | Self::IdtxDct => TxType1D::Dct,
Self::DctAdst | Self::AdstAdst | Self::FlipAdstAdst | Self::IdtxAdst => TxType1D::Adst,
Self::DctFlipAdst
| Self::AdstFlipAdst
| Self::FlipAdstFlipAdst
| Self::IdtxFlipAdst => TxType1D::FlipAdst,
Self::DctIdtx | Self::AdstIdtx | Self::FlipAdstIdtx | Self::IdtxIdtx => {
TxType1D::Identity
}
}
}
#[must_use]
pub const fn from_u8(val: u8) -> Option<Self> {
match val {
0 => Some(Self::DctDct),
1 => Some(Self::AdstDct),
2 => Some(Self::DctAdst),
3 => Some(Self::AdstAdst),
4 => Some(Self::FlipAdstDct),
5 => Some(Self::DctFlipAdst),
6 => Some(Self::FlipAdstAdst),
7 => Some(Self::AdstFlipAdst),
8 => Some(Self::FlipAdstFlipAdst),
9 => Some(Self::IdtxDct),
10 => Some(Self::DctIdtx),
11 => Some(Self::IdtxAdst),
12 => Some(Self::AdstIdtx),
13 => Some(Self::IdtxFlipAdst),
14 => Some(Self::FlipAdstIdtx),
15 => Some(Self::IdtxIdtx),
_ => None,
}
}
#[must_use]
pub const fn is_valid_for_size(self, tx_size: TxSize) -> bool {
let has_identity = matches!(
self,
Self::IdtxDct
| Self::DctIdtx
| Self::IdtxAdst
| Self::AdstIdtx
| Self::IdtxFlipAdst
| Self::FlipAdstIdtx
| Self::IdtxIdtx
);
if has_identity {
!matches!(
tx_size,
TxSize::Tx64x64
| TxSize::Tx32x64
| TxSize::Tx64x32
| TxSize::Tx16x64
| TxSize::Tx64x16
)
} else {
true
}
}
#[must_use]
pub const fn tx_class(self) -> TxClass {
match self {
Self::DctDct
| Self::AdstDct
| Self::DctAdst
| Self::AdstAdst
| Self::FlipAdstDct
| Self::DctFlipAdst
| Self::FlipAdstAdst
| Self::AdstFlipAdst
| Self::FlipAdstFlipAdst => TxClass::Class2D,
Self::IdtxDct | Self::IdtxAdst | Self::IdtxFlipAdst => TxClass::ClassVert,
Self::DctIdtx | Self::AdstIdtx | Self::FlipAdstIdtx | Self::IdtxIdtx => {
TxClass::ClassHoriz
}
}
}
}
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub enum TxClass {
#[default]
Class2D = 0,
ClassHoriz = 1,
ClassVert = 2,
}
impl TxClass {
#[must_use]
pub const fn count() -> usize {
3
}
#[must_use]
pub const fn from_u8(val: u8) -> Option<Self> {
match val {
0 => Some(Self::Class2D),
1 => Some(Self::ClassHoriz),
2 => Some(Self::ClassVert),
_ => None,
}
}
}
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub enum TxSize {
#[default]
Tx4x4 = 0,
Tx8x8 = 1,
Tx16x16 = 2,
Tx32x32 = 3,
Tx64x64 = 4,
Tx4x8 = 5,
Tx8x4 = 6,
Tx8x16 = 7,
Tx16x8 = 8,
Tx16x32 = 9,
Tx32x16 = 10,
Tx32x64 = 11,
Tx64x32 = 12,
Tx4x16 = 13,
Tx16x4 = 14,
Tx8x32 = 15,
Tx32x8 = 16,
Tx16x64 = 17,
Tx64x16 = 18,
}
impl TxSize {
#[must_use]
pub const fn width(self) -> u32 {
match self {
Self::Tx4x4 | Self::Tx4x8 | Self::Tx4x16 => 4,
Self::Tx8x8 | Self::Tx8x4 | Self::Tx8x16 | Self::Tx8x32 => 8,
Self::Tx16x16 | Self::Tx16x8 | Self::Tx16x32 | Self::Tx16x4 | Self::Tx16x64 => 16,
Self::Tx32x32 | Self::Tx32x16 | Self::Tx32x64 | Self::Tx32x8 => 32,
Self::Tx64x64 | Self::Tx64x32 | Self::Tx64x16 => 64,
}
}
#[must_use]
pub const fn height(self) -> u32 {
match self {
Self::Tx4x4 | Self::Tx8x4 | Self::Tx16x4 => 4,
Self::Tx8x8 | Self::Tx4x8 | Self::Tx16x8 | Self::Tx32x8 => 8,
Self::Tx16x16 | Self::Tx8x16 | Self::Tx32x16 | Self::Tx4x16 | Self::Tx64x16 => 16,
Self::Tx32x32 | Self::Tx16x32 | Self::Tx64x32 | Self::Tx8x32 => 32,
Self::Tx64x64 | Self::Tx32x64 | Self::Tx16x64 => 64,
}
}
#[must_use]
pub const fn width_log2(self) -> u8 {
match self.width() {
4 => 2,
8 => 3,
16 => 4,
32 => 5,
64 => 6,
_ => 0,
}
}
#[must_use]
pub const fn height_log2(self) -> u8 {
match self.height() {
4 => 2,
8 => 3,
16 => 4,
32 => 5,
64 => 6,
_ => 0,
}
}
#[must_use]
pub const fn area(self) -> u32 {
self.width() * self.height()
}
#[must_use]
pub const fn is_square(self) -> bool {
self.width() == self.height()
}
#[must_use]
pub const fn sqr_size(self) -> TxSizeSqr {
match const_min_u32(self.width(), self.height()) {
4 => TxSizeSqr::Tx4x4,
8 => TxSizeSqr::Tx8x8,
16 => TxSizeSqr::Tx16x16,
32 => TxSizeSqr::Tx32x32,
64 => TxSizeSqr::Tx64x64,
_ => TxSizeSqr::Tx4x4,
}
}
#[must_use]
pub const fn from_u8(val: u8) -> Option<Self> {
match val {
0 => Some(Self::Tx4x4),
1 => Some(Self::Tx8x8),
2 => Some(Self::Tx16x16),
3 => Some(Self::Tx32x32),
4 => Some(Self::Tx64x64),
5 => Some(Self::Tx4x8),
6 => Some(Self::Tx8x4),
7 => Some(Self::Tx8x16),
8 => Some(Self::Tx16x8),
9 => Some(Self::Tx16x32),
10 => Some(Self::Tx32x16),
11 => Some(Self::Tx32x64),
12 => Some(Self::Tx64x32),
13 => Some(Self::Tx4x16),
14 => Some(Self::Tx16x4),
15 => Some(Self::Tx8x32),
16 => Some(Self::Tx32x8),
17 => Some(Self::Tx16x64),
18 => Some(Self::Tx64x16),
_ => None,
}
}
#[must_use]
pub const fn from_dimensions(width: u32, height: u32) -> Option<Self> {
match (width, height) {
(4, 4) => Some(Self::Tx4x4),
(8, 8) => Some(Self::Tx8x8),
(16, 16) => Some(Self::Tx16x16),
(32, 32) => Some(Self::Tx32x32),
(64, 64) => Some(Self::Tx64x64),
(4, 8) => Some(Self::Tx4x8),
(8, 4) => Some(Self::Tx8x4),
(8, 16) => Some(Self::Tx8x16),
(16, 8) => Some(Self::Tx16x8),
(16, 32) => Some(Self::Tx16x32),
(32, 16) => Some(Self::Tx32x16),
(32, 64) => Some(Self::Tx32x64),
(64, 32) => Some(Self::Tx64x32),
(4, 16) => Some(Self::Tx4x16),
(16, 4) => Some(Self::Tx16x4),
(8, 32) => Some(Self::Tx8x32),
(32, 8) => Some(Self::Tx32x8),
(16, 64) => Some(Self::Tx16x64),
(64, 16) => Some(Self::Tx64x16),
_ => None,
}
}
#[must_use]
#[allow(clippy::cast_possible_truncation)]
pub const fn max_eob(self) -> u16 {
self.area() as u16
}
}
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub enum TxSizeSqr {
#[default]
Tx4x4 = 0,
Tx8x8 = 1,
Tx16x16 = 2,
Tx32x32 = 3,
Tx64x64 = 4,
}
impl TxSizeSqr {
#[must_use]
pub const fn size(self) -> u32 {
match self {
Self::Tx4x4 => 4,
Self::Tx8x8 => 8,
Self::Tx16x16 => 16,
Self::Tx32x32 => 32,
Self::Tx64x64 => 64,
}
}
#[must_use]
pub const fn log2(self) -> u8 {
match self {
Self::Tx4x4 => 2,
Self::Tx8x8 => 3,
Self::Tx16x16 => 4,
Self::Tx32x32 => 5,
Self::Tx64x64 => 6,
}
}
}
#[derive(Clone, Debug, Default)]
pub struct TransformContext {
pub tx_size: TxSize,
pub tx_type: TxType,
pub plane: u8,
pub row: u32,
pub col: u32,
pub skip: bool,
pub eob: u16,
pub bit_depth: u8,
pub lossless: bool,
}
impl TransformContext {
#[must_use]
pub const fn new(tx_size: TxSize, tx_type: TxType, plane: u8) -> Self {
Self {
tx_size,
tx_type,
plane,
row: 0,
col: 0,
skip: false,
eob: 0,
bit_depth: 8,
lossless: false,
}
}
pub fn set_position(&mut self, row: u32, col: u32) {
self.row = row;
self.col = col;
}
#[must_use]
pub const fn tx_class(&self) -> TxClass {
self.tx_type.tx_class()
}
#[must_use]
pub const fn stride(&self) -> u32 {
self.tx_size.width()
}
#[must_use]
pub const fn num_coeffs(&self) -> u32 {
self.tx_size.area()
}
#[must_use]
pub const fn is_luma(&self) -> bool {
self.plane == 0
}
#[must_use]
pub const fn is_chroma(&self) -> bool {
self.plane > 0
}
#[must_use]
pub const fn effective_size(&self) -> (u32, u32) {
let w = self.tx_size.width();
let h = self.tx_size.height();
(const_min_u32(w, 32), const_min_u32(h, 32))
}
}
#[must_use]
#[allow(clippy::cast_possible_truncation)]
fn cos_value(n: usize, k: usize, size: usize) -> i32 {
let angle = PI * (2.0 * k as f64 + 1.0) * n as f64 / (2.0 * size as f64);
(angle.cos() * f64::from(1 << COS_BIT)).round() as i32
}
#[must_use]
#[allow(clippy::cast_possible_truncation)]
fn sin_value(n: usize, k: usize, size: usize) -> i32 {
let angle = PI * (2.0 * n as f64 + 1.0) * (2.0 * k as f64 + 1.0) / (4.0 * size as f64);
(angle.sin() * f64::from(1 << COS_BIT)).round() as i32
}
#[must_use]
fn round_shift_sat(value: i64, shift: u8) -> i32 {
let shifted = if shift == 0 {
value
} else {
let round = 1i64 << (shift - 1);
(value + round) >> shift
};
shifted.clamp(i64::from(TX_COEFF_MIN), i64::from(TX_COEFF_MAX)) as i32
}
#[allow(clippy::cast_possible_truncation)]
pub fn dct4(input: &[i32; 4], output: &mut [i32; 4], cos_bit: u8) {
let s0 = input[0] + input[3];
let s1 = input[1] + input[2];
let s2 = input[1] - input[2];
let s3 = input[0] - input[3];
let cos_k = [
cos_value(0, 0, 8), cos_value(1, 0, 8), cos_value(2, 0, 8), cos_value(3, 0, 8), ];
let t0 = i64::from(s0 + s1) * i64::from(cos_k[0]);
let t1 = i64::from(s0 - s1) * i64::from(cos_k[2]);
let t2 = i64::from(s2) * i64::from(cos_k[3]) + i64::from(s3) * i64::from(cos_k[1]);
let t3 = i64::from(s3) * i64::from(cos_k[3]) - i64::from(s2) * i64::from(cos_k[1]);
output[0] = round_shift_sat(t0, cos_bit);
output[1] = round_shift_sat(t2, cos_bit);
output[2] = round_shift_sat(t1, cos_bit);
output[3] = round_shift_sat(t3, cos_bit);
}
#[allow(clippy::cast_possible_truncation, clippy::similar_names)]
pub fn dct8(input: &[i32; 8], output: &mut [i32; 8], cos_bit: u8) {
let s0 = input[0] + input[7];
let s1 = input[1] + input[6];
let s2 = input[2] + input[5];
let s3 = input[3] + input[4];
let s4 = input[3] - input[4];
let s5 = input[2] - input[5];
let s6 = input[1] - input[6];
let s7 = input[0] - input[7];
let even_in = [s0, s1, s2, s3];
let mut even_out = [0i32; 4];
let e0 = even_in[0] + even_in[3];
let e1 = even_in[1] + even_in[2];
let e2 = even_in[1] - even_in[2];
let e3 = even_in[0] - even_in[3];
even_out[0] = round_shift_sat(i64::from(e0 + e1) * i64::from(cos_value(0, 0, 16)), cos_bit);
even_out[2] = round_shift_sat(i64::from(e0 - e1) * i64::from(cos_value(4, 0, 16)), cos_bit);
even_out[1] = round_shift_sat(
i64::from(e2) * i64::from(cos_value(6, 0, 16))
+ i64::from(e3) * i64::from(cos_value(2, 0, 16)),
cos_bit,
);
even_out[3] = round_shift_sat(
i64::from(e3) * i64::from(cos_value(6, 0, 16))
- i64::from(e2) * i64::from(cos_value(2, 0, 16)),
cos_bit,
);
let cos1 = cos_value(1, 0, 16);
let cos3 = cos_value(3, 0, 16);
let cos5 = cos_value(5, 0, 16);
let cos7 = cos_value(7, 0, 16);
let o0 = round_shift_sat(
i64::from(s4) * i64::from(cos7) + i64::from(s7) * i64::from(cos1),
cos_bit,
);
let o1 = round_shift_sat(
i64::from(s5) * i64::from(cos5) + i64::from(s6) * i64::from(cos3),
cos_bit,
);
let o2 = round_shift_sat(
i64::from(s6) * i64::from(cos5) - i64::from(s5) * i64::from(cos3),
cos_bit,
);
let o3 = round_shift_sat(
i64::from(s7) * i64::from(cos7) - i64::from(s4) * i64::from(cos1),
cos_bit,
);
output[0] = even_out[0];
output[1] = o0;
output[2] = even_out[1];
output[3] = o1;
output[4] = even_out[2];
output[5] = o2;
output[6] = even_out[3];
output[7] = o3;
}
#[allow(clippy::cast_possible_truncation)]
pub fn dct16(input: &[i32; 16], output: &mut [i32; 16], cos_bit: u8) {
let mut even = [0i32; 8];
let mut odd = [0i32; 8];
for i in 0..8 {
even[i] = input[i] + input[15 - i];
odd[i] = input[i] - input[15 - i];
}
let mut even_out = [0i32; 8];
dct8(&even, &mut even_out, cos_bit);
for i in 0..8 {
let cos_idx = 2 * i + 1;
let cos_val = cos_value(cos_idx, 0, 32);
output[2 * i + 1] = round_shift_sat(i64::from(odd[i]) * i64::from(cos_val), cos_bit);
}
for i in 0..8 {
output[2 * i] = even_out[i];
}
}
pub fn dct32(input: &[i32; 32], output: &mut [i32; 32], cos_bit: u8) {
let mut even = [0i32; 16];
let mut odd = [0i32; 16];
for i in 0..16 {
even[i] = input[i] + input[31 - i];
odd[i] = input[i] - input[31 - i];
}
let mut even_out = [0i32; 16];
dct16(&even, &mut even_out, cos_bit);
for i in 0..16 {
let cos_idx = 2 * i + 1;
let cos_val = cos_value(cos_idx, 0, 64);
output[2 * i + 1] = round_shift_sat(i64::from(odd[i]) * i64::from(cos_val), cos_bit);
}
for i in 0..16 {
output[2 * i] = even_out[i];
}
}
pub fn dct64(input: &[i32; 64], output: &mut [i32; 64], cos_bit: u8) {
let mut even = [0i32; 32];
let mut odd = [0i32; 32];
for i in 0..32 {
even[i] = input[i] + input[63 - i];
odd[i] = input[i] - input[63 - i];
}
let mut even_out = [0i32; 32];
dct32(&even, &mut even_out, cos_bit);
for i in 0..32 {
let cos_idx = 2 * i + 1;
let cos_val = cos_value(cos_idx, 0, 128);
output[2 * i + 1] = round_shift_sat(i64::from(odd[i]) * i64::from(cos_val), cos_bit);
}
for i in 0..32 {
output[2 * i] = even_out[i];
}
}
#[allow(clippy::cast_possible_truncation)]
pub fn adst4(input: &[i32; 4], output: &mut [i32; 4], cos_bit: u8) {
let sin_pi_9 = sin_value(0, 0, 9);
let sin_2pi_9 = sin_value(1, 0, 9);
let sin_3pi_9 = sin_value(2, 0, 9);
let sin_4pi_9 = sin_value(3, 0, 9);
let s0 = i64::from(input[0]) * i64::from(sin_pi_9);
let s1 = i64::from(input[0]) * i64::from(sin_2pi_9);
let s2 = i64::from(input[1]) * i64::from(sin_3pi_9);
let s3 = i64::from(input[2]) * i64::from(sin_4pi_9);
let s4 = i64::from(input[2]) * i64::from(sin_pi_9);
let s5 = i64::from(input[3]) * i64::from(sin_2pi_9);
let s6 = i64::from(input[3]) * i64::from(sin_4pi_9);
let t0 = s0 + s3 + s5;
let t1 = s1 + s2 - s6;
let t2 = s1 - s2 + s6;
let t3 = s0 - s3 + s4;
output[0] = round_shift_sat(t0, cos_bit);
output[1] = round_shift_sat(t1, cos_bit);
output[2] = round_shift_sat(t2, cos_bit);
output[3] = round_shift_sat(t3, cos_bit);
}
pub fn adst8(input: &[i32; 8], output: &mut [i32; 8], cos_bit: u8) {
for (i, out) in output.iter_mut().enumerate() {
let mut sum = 0i64;
for (j, &inp) in input.iter().enumerate() {
let sin_val = sin_value(i, j, 8);
sum += i64::from(inp) * i64::from(sin_val);
}
*out = round_shift_sat(sum, cos_bit);
}
}
pub fn adst16(input: &[i32; 16], output: &mut [i32; 16], cos_bit: u8) {
for (i, out) in output.iter_mut().enumerate() {
let mut sum = 0i64;
for (j, &inp) in input.iter().enumerate() {
let sin_val = sin_value(i, j, 16);
sum += i64::from(inp) * i64::from(sin_val);
}
*out = round_shift_sat(sum, cos_bit);
}
}
pub fn identity4(input: &[i32; 4], output: &mut [i32; 4]) {
for (i, &val) in input.iter().enumerate() {
output[i] = val * 2; }
}
pub fn identity8(input: &[i32; 8], output: &mut [i32; 8]) {
for (i, &val) in input.iter().enumerate() {
output[i] = val * 2;
}
}
pub fn identity16(input: &[i32; 16], output: &mut [i32; 16]) {
for (i, &val) in input.iter().enumerate() {
output[i] = val * 2;
}
}
pub fn identity32(input: &[i32; 32], output: &mut [i32; 32]) {
for (i, &val) in input.iter().enumerate() {
output[i] = val * 4; }
}
pub fn idct4(input: &[i32; 4], output: &mut [i32; 4], cos_bit: u8) {
let cos0 = cos_value(0, 0, 8);
let cos1 = cos_value(1, 0, 8);
let cos2 = cos_value(2, 0, 8);
let cos3 = cos_value(3, 0, 8);
let t0 = i64::from(input[0]) * i64::from(cos0);
let t1 = i64::from(input[2]) * i64::from(cos2);
let t2 = i64::from(input[1]) * i64::from(cos1) + i64::from(input[3]) * i64::from(cos3);
let t3 = i64::from(input[1]) * i64::from(cos3) - i64::from(input[3]) * i64::from(cos1);
let s0 = round_shift_sat(t0 + t1, cos_bit);
let s1 = round_shift_sat(t0 - t1, cos_bit);
let s2 = round_shift_sat(t2, cos_bit);
let s3 = round_shift_sat(t3, cos_bit);
output[0] = s0 + s2;
output[1] = s1 + s3;
output[2] = s1 - s3;
output[3] = s0 - s2;
}
pub fn idct8(input: &[i32; 8], output: &mut [i32; 8], cos_bit: u8) {
for (i, out) in output.iter_mut().enumerate() {
let mut sum = 0i64;
for (j, &inp) in input.iter().enumerate() {
let cos_val = cos_value(j, i, 8);
sum += i64::from(inp) * i64::from(cos_val);
}
*out = round_shift_sat(sum, cos_bit);
}
}
pub fn idct16(input: &[i32; 16], output: &mut [i32; 16], cos_bit: u8) {
for (i, out) in output.iter_mut().enumerate() {
let mut sum = 0i64;
for (j, &inp) in input.iter().enumerate() {
let cos_val = cos_value(j, i, 16);
sum += i64::from(inp) * i64::from(cos_val);
}
*out = round_shift_sat(sum, cos_bit);
}
}
pub fn idct32(input: &[i32; 32], output: &mut [i32; 32], cos_bit: u8) {
for (i, out) in output.iter_mut().enumerate() {
let mut sum = 0i64;
for (j, &inp) in input.iter().enumerate() {
let cos_val = cos_value(j, i, 32);
sum += i64::from(inp) * i64::from(cos_val);
}
*out = round_shift_sat(sum, cos_bit);
}
}
pub fn idct64(input: &[i32; 64], output: &mut [i32; 64], cos_bit: u8) {
for (i, out) in output.iter_mut().enumerate() {
let mut sum = 0i64;
for (j, &inp) in input.iter().enumerate() {
let cos_val = cos_value(j, i, 64);
sum += i64::from(inp) * i64::from(cos_val);
}
*out = round_shift_sat(sum, cos_bit);
}
}
pub fn iadst4(input: &[i32; 4], output: &mut [i32; 4], cos_bit: u8) {
adst4(input, output, cos_bit);
}
pub fn iadst8(input: &[i32; 8], output: &mut [i32; 8], cos_bit: u8) {
adst8(input, output, cos_bit);
}
pub fn iadst16(input: &[i32; 16], output: &mut [i32; 16], cos_bit: u8) {
adst16(input, output, cos_bit);
}
#[derive(Clone, Debug)]
pub struct Transform2D {
buffer: Vec<i32>,
tx_size: TxSize,
tx_type: TxType,
}
impl Transform2D {
#[must_use]
pub fn new(tx_size: TxSize, tx_type: TxType) -> Self {
let area = tx_size.area() as usize;
Self {
buffer: vec![0; area],
tx_size,
tx_type,
}
}
pub fn inverse(&mut self, input: &[i32], output: &mut [i32]) {
let width = self.tx_size.width() as usize;
let height = self.tx_size.height() as usize;
let _cos_bit = COS_BIT;
for row in 0..height {
let row_start = row * width;
self.apply_row_inverse(&input[row_start..row_start + width], row);
}
for col in 0..width {
self.apply_col_inverse(col, &mut output[col..], width);
}
}
fn apply_row_inverse(&mut self, input: &[i32], row: usize) {
let width = self.tx_size.width() as usize;
let row_type = self.tx_type.row_type();
let cos_bit = COS_BIT;
let mut row_out = vec![0i32; width];
match (row_type, width) {
(TxType1D::Dct, 4) => {
let mut inp = [0i32; 4];
let mut out = [0i32; 4];
inp.copy_from_slice(input);
idct4(&inp, &mut out, cos_bit);
row_out.copy_from_slice(&out);
}
(TxType1D::Dct, 8) => {
let mut inp = [0i32; 8];
let mut out = [0i32; 8];
inp.copy_from_slice(input);
idct8(&inp, &mut out, cos_bit);
row_out.copy_from_slice(&out);
}
(TxType1D::Adst, 4) => {
let mut inp = [0i32; 4];
let mut out = [0i32; 4];
inp.copy_from_slice(input);
iadst4(&inp, &mut out, cos_bit);
row_out.copy_from_slice(&out);
}
(TxType1D::Adst, 8) => {
let mut inp = [0i32; 8];
let mut out = [0i32; 8];
inp.copy_from_slice(input);
iadst8(&inp, &mut out, cos_bit);
row_out.copy_from_slice(&out);
}
(TxType1D::Identity, 4) => {
let mut inp = [0i32; 4];
let mut out = [0i32; 4];
inp.copy_from_slice(input);
identity4(&inp, &mut out);
row_out.copy_from_slice(&out);
}
(TxType1D::Identity, 8) => {
let mut inp = [0i32; 8];
let mut out = [0i32; 8];
inp.copy_from_slice(input);
identity8(&inp, &mut out);
row_out.copy_from_slice(&out);
}
(TxType1D::FlipAdst, n) => {
let mut temp = vec![0i32; n];
match n {
4 => {
let mut inp = [0i32; 4];
let mut out = [0i32; 4];
inp.copy_from_slice(input);
iadst4(&inp, &mut out, cos_bit);
temp.copy_from_slice(&out);
}
8 => {
let mut inp = [0i32; 8];
let mut out = [0i32; 8];
inp.copy_from_slice(input);
iadst8(&inp, &mut out, cos_bit);
temp.copy_from_slice(&out);
}
_ => temp.copy_from_slice(input),
}
for i in 0..n {
row_out[i] = temp[n - 1 - i];
}
}
_ => {
row_out[..width].copy_from_slice(&input[..width]);
}
}
let row_start = row * width;
self.buffer[row_start..row_start + width].copy_from_slice(&row_out);
}
fn apply_col_inverse(&self, col: usize, output: &mut [i32], stride: usize) {
let width = self.tx_size.width() as usize;
let height = self.tx_size.height() as usize;
let col_type = self.tx_type.col_type();
let cos_bit = COS_BIT;
let mut col_in = vec![0i32; height];
for row in 0..height {
col_in[row] = self.buffer[row * width + col];
}
let mut col_out = vec![0i32; height];
match (col_type, height) {
(TxType1D::Dct, 4) => {
let mut inp = [0i32; 4];
let mut out = [0i32; 4];
inp.copy_from_slice(&col_in);
idct4(&inp, &mut out, cos_bit);
col_out.copy_from_slice(&out);
}
(TxType1D::Dct, 8) => {
let mut inp = [0i32; 8];
let mut out = [0i32; 8];
inp.copy_from_slice(&col_in);
idct8(&inp, &mut out, cos_bit);
col_out.copy_from_slice(&out);
}
(TxType1D::Adst, 4) => {
let mut inp = [0i32; 4];
let mut out = [0i32; 4];
inp.copy_from_slice(&col_in);
iadst4(&inp, &mut out, cos_bit);
col_out.copy_from_slice(&out);
}
(TxType1D::Adst, 8) => {
let mut inp = [0i32; 8];
let mut out = [0i32; 8];
inp.copy_from_slice(&col_in);
iadst8(&inp, &mut out, cos_bit);
col_out.copy_from_slice(&out);
}
(TxType1D::Identity, 4) => {
let mut inp = [0i32; 4];
let mut out = [0i32; 4];
inp.copy_from_slice(&col_in);
identity4(&inp, &mut out);
col_out.copy_from_slice(&out);
}
(TxType1D::Identity, 8) => {
let mut inp = [0i32; 8];
let mut out = [0i32; 8];
inp.copy_from_slice(&col_in);
identity8(&inp, &mut out);
col_out.copy_from_slice(&out);
}
(TxType1D::FlipAdst, n) => {
let mut temp = vec![0i32; n];
match n {
4 => {
let mut inp = [0i32; 4];
let mut out = [0i32; 4];
inp.copy_from_slice(&col_in);
iadst4(&inp, &mut out, cos_bit);
temp.copy_from_slice(&out);
}
8 => {
let mut inp = [0i32; 8];
let mut out = [0i32; 8];
inp.copy_from_slice(&col_in);
iadst8(&inp, &mut out, cos_bit);
temp.copy_from_slice(&out);
}
_ => temp.copy_from_slice(&col_in),
}
for i in 0..n {
col_out[i] = temp[n - 1 - i];
}
}
_ => {
col_out.copy_from_slice(&col_in);
}
}
for row in 0..height {
output[row * stride] = col_out[row];
}
}
}
pub fn flip_horizontal(coeffs: &mut [i32], width: usize, height: usize) {
for row in 0..height {
let row_start = row * width;
coeffs[row_start..row_start + width].reverse();
}
}
pub fn flip_vertical(coeffs: &mut [i32], width: usize, height: usize) {
for col in 0..width {
for row in 0..height / 2 {
let top = row * width + col;
let bottom = (height - 1 - row) * width + col;
coeffs.swap(top, bottom);
}
}
}
pub fn wht4x4(input: &[i32; 16], output: &mut [i32; 16]) {
for (i, &val) in input.iter().enumerate() {
output[i] = val;
}
for row in 0..4 {
let i = row * 4;
let a = output[i] + output[i + 1];
let b = output[i + 2] + output[i + 3];
let c = output[i] - output[i + 1];
let d = output[i + 2] - output[i + 3];
output[i] = a + b;
output[i + 1] = c + d;
output[i + 2] = a - b;
output[i + 3] = c - d;
}
for col in 0..4 {
let a = output[col] + output[col + 4];
let b = output[col + 8] + output[col + 12];
let c = output[col] - output[col + 4];
let d = output[col + 8] - output[col + 12];
output[col] = (a + b) >> 2;
output[col + 4] = (c + d) >> 2;
output[col + 8] = (a - b) >> 2;
output[col + 12] = (c - d) >> 2;
}
}
pub fn iwht4x4(input: &[i32; 16], output: &mut [i32; 16]) {
wht4x4(input, output);
}
#[must_use]
pub const fn get_reduced_tx_size(tx_size: TxSize) -> (u32, u32) {
let width = tx_size.width();
let height = tx_size.height();
(const_min_u32(width, 32), const_min_u32(height, 32))
}
#[must_use]
pub const fn needs_reduction(tx_size: TxSize) -> bool {
tx_size.width() > 32 || tx_size.height() > 32
}
#[must_use]
pub const fn get_max_nonzero_coeffs(tx_size: TxSize) -> u32 {
let (w, h) = get_reduced_tx_size(tx_size);
w * h
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_tx_type_components() {
assert_eq!(TxType::DctDct.row_type(), TxType1D::Dct);
assert_eq!(TxType::DctDct.col_type(), TxType1D::Dct);
assert_eq!(TxType::AdstDct.row_type(), TxType1D::Adst);
assert_eq!(TxType::AdstDct.col_type(), TxType1D::Dct);
assert_eq!(TxType::IdtxIdtx.row_type(), TxType1D::Identity);
assert_eq!(TxType::IdtxIdtx.col_type(), TxType1D::Identity);
}
#[test]
fn test_tx_size_dimensions() {
assert_eq!(TxSize::Tx4x4.width(), 4);
assert_eq!(TxSize::Tx4x4.height(), 4);
assert_eq!(TxSize::Tx4x8.width(), 4);
assert_eq!(TxSize::Tx4x8.height(), 8);
assert_eq!(TxSize::Tx64x64.width(), 64);
assert_eq!(TxSize::Tx64x64.height(), 64);
}
#[test]
fn test_tx_size_log2() {
assert_eq!(TxSize::Tx4x4.width_log2(), 2);
assert_eq!(TxSize::Tx8x8.width_log2(), 3);
assert_eq!(TxSize::Tx16x16.width_log2(), 4);
assert_eq!(TxSize::Tx32x32.width_log2(), 5);
assert_eq!(TxSize::Tx64x64.width_log2(), 6);
}
#[test]
fn test_tx_size_area() {
assert_eq!(TxSize::Tx4x4.area(), 16);
assert_eq!(TxSize::Tx8x8.area(), 64);
assert_eq!(TxSize::Tx4x8.area(), 32);
}
#[test]
fn test_tx_size_is_square() {
assert!(TxSize::Tx4x4.is_square());
assert!(TxSize::Tx8x8.is_square());
assert!(!TxSize::Tx4x8.is_square());
assert!(!TxSize::Tx8x4.is_square());
}
#[test]
fn test_tx_class() {
assert_eq!(TxType::DctDct.tx_class(), TxClass::Class2D);
assert_eq!(TxType::IdtxDct.tx_class(), TxClass::ClassVert);
assert_eq!(TxType::DctIdtx.tx_class(), TxClass::ClassHoriz);
}
#[test]
fn test_tx_type_from_u8() {
assert_eq!(TxType::from_u8(0), Some(TxType::DctDct));
assert_eq!(TxType::from_u8(15), Some(TxType::IdtxIdtx));
assert_eq!(TxType::from_u8(16), None);
}
#[test]
fn test_tx_size_from_u8() {
assert_eq!(TxSize::from_u8(0), Some(TxSize::Tx4x4));
assert_eq!(TxSize::from_u8(18), Some(TxSize::Tx64x16));
assert_eq!(TxSize::from_u8(19), None);
}
#[test]
fn test_tx_size_from_dimensions() {
assert_eq!(TxSize::from_dimensions(4, 4), Some(TxSize::Tx4x4));
assert_eq!(TxSize::from_dimensions(64, 64), Some(TxSize::Tx64x64));
assert_eq!(TxSize::from_dimensions(4, 8), Some(TxSize::Tx4x8));
assert_eq!(TxSize::from_dimensions(3, 3), None);
}
#[test]
fn test_transform_context() {
let ctx = TransformContext::new(TxSize::Tx8x8, TxType::DctDct, 0);
assert_eq!(ctx.stride(), 8);
assert_eq!(ctx.num_coeffs(), 64);
assert!(ctx.is_luma());
assert!(!ctx.is_chroma());
}
#[test]
fn test_dct4_identity() {
let input = [1, 1, 1, 1];
let mut output = [0i32; 4];
dct4(&input, &mut output, COS_BIT);
assert!(output[0].abs() > output[1].abs());
}
#[test]
fn test_idct4_dct4_roundtrip() {
let input = [100, 50, -30, 80];
let mut dct_out = [0i32; 4];
let mut idct_out = [0i32; 4];
dct4(&input, &mut dct_out, COS_BIT);
idct4(&dct_out, &mut idct_out, COS_BIT);
for i in 0..4 {
let diff = (input[i] - idct_out[i]).abs();
assert!(diff < 500, "Roundtrip error too large at {i}: {diff}");
}
}
#[test]
fn test_identity_transform() {
let input = [1, 2, 3, 4];
let mut output = [0i32; 4];
identity4(&input, &mut output);
for i in 0..4 {
assert_eq!(output[i], input[i] * 2);
}
}
#[test]
fn test_wht4x4() {
let input = [1i32; 16];
let mut output = [0i32; 16];
wht4x4(&input, &mut output);
assert_ne!(output[0], 0);
}
#[test]
fn test_reduced_tx_size() {
assert_eq!(get_reduced_tx_size(TxSize::Tx4x4), (4, 4));
assert_eq!(get_reduced_tx_size(TxSize::Tx64x64), (32, 32));
assert_eq!(get_reduced_tx_size(TxSize::Tx64x32), (32, 32));
}
#[test]
fn test_needs_reduction() {
assert!(!needs_reduction(TxSize::Tx32x32));
assert!(needs_reduction(TxSize::Tx64x64));
assert!(needs_reduction(TxSize::Tx64x32));
}
#[test]
fn test_max_nonzero_coeffs() {
assert_eq!(get_max_nonzero_coeffs(TxSize::Tx4x4), 16);
assert_eq!(get_max_nonzero_coeffs(TxSize::Tx64x64), 1024); }
#[test]
fn test_tx_type_valid_for_size() {
assert!(TxType::DctDct.is_valid_for_size(TxSize::Tx64x64));
assert!(!TxType::IdtxIdtx.is_valid_for_size(TxSize::Tx64x64));
assert!(TxType::IdtxIdtx.is_valid_for_size(TxSize::Tx32x32));
}
#[test]
fn test_flip_horizontal() {
let mut coeffs = [1, 2, 3, 4, 5, 6, 7, 8];
flip_horizontal(&mut coeffs, 4, 2);
assert_eq!(coeffs, [4, 3, 2, 1, 8, 7, 6, 5]);
}
#[test]
fn test_flip_vertical() {
let mut coeffs = [1, 2, 3, 4, 5, 6, 7, 8];
flip_vertical(&mut coeffs, 4, 2);
assert_eq!(coeffs, [5, 6, 7, 8, 1, 2, 3, 4]);
}
#[test]
fn test_transform_2d_new() {
let tx = Transform2D::new(TxSize::Tx8x8, TxType::DctDct);
assert_eq!(tx.buffer.len(), 64);
}
#[test]
fn test_tx_size_sqr() {
assert_eq!(TxSizeSqr::Tx4x4.size(), 4);
assert_eq!(TxSizeSqr::Tx8x8.size(), 8);
assert_eq!(TxSizeSqr::Tx64x64.log2(), 6);
}
#[test]
fn test_cos_value() {
let cos0 = cos_value(0, 0, 8);
assert!(cos0 > 0);
let cos_half_pi = cos_value(4, 0, 8);
assert!(cos_half_pi.abs() < cos0);
}
#[test]
fn test_round_shift_sat() {
assert_eq!(round_shift_sat(100, 2), 25);
assert_eq!(round_shift_sat(100, 1), 50);
let max_plus = i64::from(TX_COEFF_MAX) * 4; assert_eq!(round_shift_sat(max_plus, 1), TX_COEFF_MAX);
let min_minus = i64::from(TX_COEFF_MIN) * 4; assert_eq!(round_shift_sat(min_minus, 1), TX_COEFF_MIN);
}
#[test]
fn test_constants() {
assert_eq!(TX_TYPES, 16);
assert_eq!(TX_SIZES, 19);
assert_eq!(MAX_TX_SIZE, 64);
assert_eq!(MAX_TX_SQUARE, 4096);
}
}