use gamut_dsp::{
clip3, forward_adst, forward_dct, forward_identity, inverse_adst, inverse_dct,
inverse_identity, round2,
};
/// Transform block size, named `Tx{width}x{height}` in samples.
///
/// The discriminants are written out explicitly because they index the per-size
/// lookup tables (`TX_WIDTH_LOG2`, `TX_HEIGHT_LOG2`, `TRANSFORM_ROW_SHIFT`,
/// `FWD_SHIFT`) through `self as usize`. Reordering the variants without
/// reordering those tables would silently return wrong dimensions and shifts.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum TxSize {
    /// 4 wide, 4 tall.
    Tx4x4 = 0,
    /// 8 wide, 8 tall.
    Tx8x8 = 1,
    /// 16 wide, 16 tall.
    Tx16x16 = 2,
    /// 32 wide, 32 tall.
    Tx32x32 = 3,
    /// 64 wide, 64 tall.
    Tx64x64 = 4,
    /// 4 wide, 8 tall.
    Tx4x8 = 5,
    /// 8 wide, 4 tall.
    Tx8x4 = 6,
    /// 8 wide, 16 tall.
    Tx8x16 = 7,
    /// 16 wide, 8 tall.
    Tx16x8 = 8,
    /// 16 wide, 32 tall.
    Tx16x32 = 9,
    /// 32 wide, 16 tall.
    Tx32x16 = 10,
    /// 32 wide, 64 tall.
    Tx32x64 = 11,
    /// 64 wide, 32 tall.
    Tx64x32 = 12,
    /// 4 wide, 16 tall.
    Tx4x16 = 13,
    /// 16 wide, 4 tall.
    Tx16x4 = 14,
    /// 8 wide, 32 tall.
    Tx8x32 = 15,
    /// 32 wide, 8 tall.
    Tx32x8 = 16,
    /// 16 wide, 64 tall.
    Tx16x64 = 17,
    /// 64 wide, 16 tall.
    Tx64x16 = 18,
}
/// Base-2 logarithm of the transform width, indexed by `TxSize as usize`.
const TX_WIDTH_LOG2: [u32; 19] = [
    2, 3, 4, 5, 6, // 4x4 8x8 16x16 32x32 64x64
    2, 3, 3, 4, 4, 5, 5, 6, // 4x8 8x4 8x16 16x8 16x32 32x16 32x64 64x32
    2, 4, 3, 5, 4, 6, // 4x16 16x4 8x32 32x8 16x64 64x16
];
/// Base-2 logarithm of the transform height, indexed by `TxSize as usize`.
const TX_HEIGHT_LOG2: [u32; 19] = [
    2, 3, 4, 5, 6, // 4x4 8x8 16x16 32x32 64x64
    3, 2, 4, 3, 5, 4, 6, 5, // 4x8 8x4 8x16 16x8 16x32 32x16 32x64 64x32
    4, 2, 5, 3, 6, 4, // 4x16 16x4 8x32 32x8 16x64 64x16
];
/// Rounding right-shift applied to each output of the inverse row pass,
/// indexed by `TxSize as usize`.
const TRANSFORM_ROW_SHIFT: [u32; 19] = [
    0, 1, 2, 2, 2, // 4x4 8x8 16x16 32x32 64x64
    0, 0, 1, 1, 1, 1, 1, 1, // 4x8 8x4 8x16 16x8 16x32 32x16 32x64 64x32
    1, 1, 2, 2, 2, 2, // 4x16 16x4 8x32 32x8 16x64 64x16
];
impl TxSize {
#[must_use]
pub fn log2_width(self) -> u32 {
TX_WIDTH_LOG2[self as usize]
}
#[must_use]
pub fn log2_height(self) -> u32 {
TX_HEIGHT_LOG2[self as usize]
}
#[must_use]
pub fn width(self) -> usize {
1 << self.log2_width()
}
#[must_use]
pub fn height(self) -> usize {
1 << self.log2_height()
}
#[must_use]
pub fn dq_denom(self) -> i32 {
match self {
TxSize::Tx32x32
| TxSize::Tx16x32
| TxSize::Tx32x16
| TxSize::Tx16x64
| TxSize::Tx64x16 => 2,
TxSize::Tx64x64 | TxSize::Tx32x64 | TxSize::Tx64x32 => 4,
_ => 1,
}
}
fn row_shift(self) -> u32 {
TRANSFORM_ROW_SHIFT[self as usize]
}
}
/// One-dimensional kernel applied along a single axis of a 2D transform.
///
/// A [`TxType`] is a (column kernel, row kernel) pair of these.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Tx1d {
    /// Discrete cosine transform kernel.
    Dct,
    /// ADST kernel. Flipped-ADST types also map here; the flip itself is only
    /// reported through `TxType::flip_ud` / `TxType::flip_lr`.
    Adst,
    /// Identity kernel (`inverse_identity` / `forward_identity`).
    Identity,
}
/// Two-dimensional transform type: one 1D kernel applied down the columns
/// (vertically) and one applied across the rows (horizontally).
///
/// Two-kernel names list the column kernel first and the row kernel second, so
/// `AdstDct` means ADST on columns and DCT on rows. `V*` types put the named
/// kernel on the columns and the identity on the rows. `H*` types put the
/// identity on the columns and the named kernel on the rows. `Idtx` is the
/// identity in both directions.
///
/// "Flipadst" uses the ADST kernel. The flip is reported separately by
/// [`TxType::flip_ud`] (columns) and [`TxType::flip_lr`] (rows).
///
/// The discriminants are written out to pin the declaration order, which appears
/// to mirror AV1's transform-type numbering. NOTE(review): confirm whether any
/// caller relies on the numeric values.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum TxType {
    /// DCT on columns, DCT on rows.
    DctDct = 0,
    /// ADST on columns, DCT on rows.
    AdstDct = 1,
    /// DCT on columns, ADST on rows.
    DctAdst = 2,
    /// ADST on columns, ADST on rows.
    AdstAdst = 3,
    /// Flipped ADST on columns, DCT on rows.
    FlipadstDct = 4,
    /// DCT on columns, flipped ADST on rows.
    DctFlipadst = 5,
    /// Flipped ADST on both columns and rows.
    FlipadstFlipadst = 6,
    /// ADST on columns, flipped ADST on rows.
    AdstFlipadst = 7,
    /// Flipped ADST on columns, ADST on rows.
    FlipadstAdst = 8,
    /// Identity on both columns and rows.
    Idtx = 9,
    /// DCT on columns, identity on rows.
    VDct = 10,
    /// Identity on columns, DCT on rows.
    HDct = 11,
    /// ADST on columns, identity on rows.
    VAdst = 12,
    /// Identity on columns, ADST on rows.
    HAdst = 13,
    /// Flipped ADST on columns, identity on rows.
    VFlipadst = 14,
    /// Identity on columns, flipped ADST on rows.
    HFlipadst = 15,
}
impl TxType {
fn row_1d(self) -> Tx1d {
use TxType::*;
match self {
DctDct | AdstDct | FlipadstDct | HDct => Tx1d::Dct,
DctAdst | AdstAdst | DctFlipadst | FlipadstFlipadst | AdstFlipadst | FlipadstAdst
| HAdst | HFlipadst => Tx1d::Adst,
Idtx | VDct | VAdst | VFlipadst => Tx1d::Identity,
}
}
fn col_1d(self) -> Tx1d {
use TxType::*;
match self {
DctDct | DctAdst | DctFlipadst | VDct => Tx1d::Dct,
AdstDct | AdstAdst | FlipadstDct | FlipadstFlipadst | AdstFlipadst | FlipadstAdst
| VAdst | VFlipadst => Tx1d::Adst,
Idtx | HDct | HAdst | HFlipadst => Tx1d::Identity,
}
}
#[must_use]
pub fn flip_ud(self) -> bool {
use TxType::*;
matches!(
self,
FlipadstDct | FlipadstAdst | VFlipadst | FlipadstFlipadst
)
}
#[must_use]
pub fn flip_lr(self) -> bool {
use TxType::*;
matches!(
self,
DctFlipadst | AdstFlipadst | HFlipadst | FlipadstFlipadst
)
}
}
/// Runs the inverse 1D kernel `kind` in place over `t`.
///
/// `n` is the base-2 logarithm of the transform length; callers pass `log2` of
/// `t.len()`. The range argument `r` is forwarded to the DCT and ADST kernels
/// only, because the identity kernel takes no range parameter.
fn inverse_1d(kind: Tx1d, t: &mut [i64], n: u32, r: u32) {
    match kind {
        Tx1d::Identity => inverse_identity(t, n),
        Tx1d::Adst => inverse_adst(t, n, r),
        Tx1d::Dct => inverse_dct(t, n, r),
    }
}
/// Runs the forward 1D kernel `kind` in place over `t`.
///
/// `n` is the base-2 logarithm of the transform length; callers pass `log2` of
/// `t.len()`.
fn forward_1d(kind: Tx1d, t: &mut [i64], n: u32) {
    match kind {
        Tx1d::Identity => forward_identity(t, n),
        Tx1d::Adst => forward_adst(t, n),
        Tx1d::Dct => forward_dct(t, n),
    }
}
#[must_use]
pub fn inverse_transform_2d(dequant: &[i32], tx: TxSize, ty: TxType, bit_depth: u32) -> Vec<i32> {
let (log2w, log2h) = (tx.log2_width(), tx.log2_height());
let (w, h) = (tx.width(), tx.height());
assert!(dequant.len() >= w * h, "inverse_transform_2d: short input");
let row_shift = tx.row_shift();
let col_shift = 4u32;
let row_clamp = bit_depth + 8;
let col_clamp = (bit_depth + 6).max(16);
let rect = (log2w as i32 - log2h as i32).abs() == 1;
let mut resid = vec![0i64; w * h];
let mut t = [0i64; 64];
for i in 0..h {
for j in 0..w {
t[j] = if i < 32 && j < 32 {
i64::from(dequant[i * w + j])
} else {
0
};
}
if rect {
for v in &mut t[..w] {
*v = round2(*v * 2896, 12);
}
}
inverse_1d(ty.row_1d(), &mut t[..w], log2w, row_clamp);
for j in 0..w {
resid[i * w + j] = round2(t[j], row_shift);
}
}
let lim = 1i64 << (col_clamp - 1);
for v in &mut resid {
*v = clip3(-lim, lim - 1, *v);
}
for j in 0..w {
for i in 0..h {
t[i] = resid[i * w + j];
}
inverse_1d(ty.col_1d(), &mut t[..h], log2h, col_clamp);
for i in 0..h {
resid[i * w + j] = round2(t[i], col_shift);
}
}
resid.iter().map(|&v| v as i32).collect()
}
#[must_use]
pub fn forward_transform_2d(residual: &[i32], tx: TxSize, ty: TxType) -> Vec<i32> {
let (log2w, log2h) = (tx.log2_width(), tx.log2_height());
let (w, h) = (tx.width(), tx.height());
assert!(residual.len() >= w * h, "forward_transform_2d: short input");
let rect = (log2w as i32 - log2h as i32).abs() == 1;
let mut coeff = vec![0i64; w * h];
let mut t = [0i64; 64];
for j in 0..w {
for i in 0..h {
t[i] = i64::from(residual[i * w + j]);
}
forward_1d(ty.col_1d(), &mut t[..h], log2h);
for i in 0..h {
coeff[i * w + j] = t[i];
}
}
for i in 0..h {
for j in 0..w {
t[j] = coeff[i * w + j];
}
forward_1d(ty.row_1d(), &mut t[..w], log2w);
for j in 0..w {
coeff[i * w + j] = t[j];
}
}
let mut shift = FWD_SHIFT[tx as usize];
if ty.row_1d() == Tx1d::Identity {
shift += log2w as i32 - 1;
}
if ty.col_1d() == Tx1d::Identity {
shift += log2h as i32 - 1;
}
coeff
.iter()
.map(|&c| {
let mut v = c;
if rect {
v = round2(v * 2896, 12);
}
v = if shift >= 0 {
v << shift
} else {
round2(v, (-shift) as u32)
};
v as i32
})
.collect()
}
/// Per-size power-of-two output scale of `forward_transform_2d`, indexed by
/// `TxSize as usize`. Positive entries are left shifts; negative entries are
/// rounding right shifts.
const FWD_SHIFT: [i32; 19] = [
    2, 1, 0, -2, -4, // 4x4 8x8 16x16 32x32 64x64
    2, 2, 1, 1, -1, -1, -3, -3, // 4x8 8x4 8x16 16x8 16x32 32x16 32x64 64x32
    1, 1, 0, 0, -2, -2, // 4x16 16x4 8x32 32x8 16x64 64x16
];
#[cfg(test)]
mod tests {
    //! Round-trip, table-consistency and flag tests for the transform drivers.
    use super::*;

    /// Small deterministic PCG-style LCG so the tests are reproducible.
    struct Lcg(u64);
    impl Lcg {
        fn next(&mut self) -> u64 {
            self.0 = self
                .0
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407);
            self.0
        }
        /// Pseudo-random value in `-128..=127`.
        fn resid(&mut self) -> i32 {
            (self.next() >> 40) as i32 % 256 - 128
        }
    }

    /// Every size, in discriminant order (checked by `size_order_matches_discriminants`).
    const ALL_SIZES: [TxSize; 19] = [
        TxSize::Tx4x4,
        TxSize::Tx8x8,
        TxSize::Tx16x16,
        TxSize::Tx32x32,
        TxSize::Tx64x64,
        TxSize::Tx4x8,
        TxSize::Tx8x4,
        TxSize::Tx8x16,
        TxSize::Tx16x8,
        TxSize::Tx16x32,
        TxSize::Tx32x16,
        TxSize::Tx32x64,
        TxSize::Tx64x32,
        TxSize::Tx4x16,
        TxSize::Tx16x4,
        TxSize::Tx8x32,
        TxSize::Tx32x8,
        TxSize::Tx16x64,
        TxSize::Tx64x16,
    ];

    /// Every transform type.
    const ALL_TYPES: [TxType; 16] = [
        TxType::DctDct,
        TxType::AdstDct,
        TxType::DctAdst,
        TxType::AdstAdst,
        TxType::FlipadstDct,
        TxType::DctFlipadst,
        TxType::FlipadstFlipadst,
        TxType::AdstFlipadst,
        TxType::FlipadstAdst,
        TxType::Idtx,
        TxType::VDct,
        TxType::HDct,
        TxType::VAdst,
        TxType::HAdst,
        TxType::VFlipadst,
        TxType::HFlipadst,
    ];

    /// Linear ramp in x and y with random slopes: low-frequency content that
    /// survives the 32x32 coefficient truncation of 64-point sizes.
    fn smooth_residual(w: usize, h: usize, rng: &mut Lcg) -> Vec<i32> {
        let a = rng.resid() / 2;
        let b = rng.resid() / 2;
        let (cw, ch) = (w as i32, h as i32);
        (0..w * h)
            .map(|p| {
                let (x, y) = ((p % w) as i32, (p / w) as i32);
                (a * (2 * x - cw + 1)) / cw + (b * (2 * y - ch + 1)) / ch
            })
            .collect()
    }

    #[test]
    fn sizes_have_consistent_dimensions() {
        for tx in ALL_SIZES {
            assert_eq!(tx.width(), 1 << tx.log2_width());
            assert_eq!(tx.height(), 1 << tx.log2_height());
            assert!(tx.width() >= 4 && tx.width() <= 64);
            assert!(tx.height() >= 4 && tx.height() <= 64);
        }
    }

    #[test]
    fn size_order_matches_discriminants() {
        // The lookup tables are indexed by `TxSize as usize`; this pins the
        // order that `ALL_SIZES` (and the tables' comments) assume.
        for (k, tx) in ALL_SIZES.iter().enumerate() {
            assert_eq!(*tx as usize, k, "{tx:?} is out of table order");
        }
    }

    #[test]
    fn dct_dct_round_trip_is_near_identity() {
        let mut rng = Lcg(0x00c0_ffee_1234_5678);
        for tx in ALL_SIZES {
            let (w, h) = (tx.width(), tx.height());
            let mut max_err = 0i32;
            for _ in 0..20 {
                let residual = smooth_residual(w, h, &mut rng);
                let coeff = forward_transform_2d(&residual, tx, TxType::DctDct);
                let recon = inverse_transform_2d(&coeff, tx, TxType::DctDct, 8);
                for (r, o) in residual.iter().zip(&recon) {
                    max_err = max_err.max((r - o).abs());
                }
            }
            let bound = 4 + 2 * (tx.log2_width() as i32 + tx.log2_height() as i32);
            assert!(
                max_err <= bound,
                "{tx:?}: round-trip max error {max_err} exceeds {bound}",
            );
        }
    }

    #[test]
    fn adst_and_identity_round_trip() {
        let mut rng = Lcg(0xdead_1010_2020_3030);
        let cases = [
            (TxSize::Tx4x4, TxType::AdstAdst),
            (TxSize::Tx8x8, TxType::AdstDct),
            (TxSize::Tx16x16, TxType::DctAdst),
            (TxSize::Tx8x8, TxType::Idtx),
            (TxSize::Tx16x16, TxType::VDct),
            (TxSize::Tx32x32, TxType::Idtx),
            (TxSize::Tx8x16, TxType::HAdst),
            (TxSize::Tx16x8, TxType::VFlipadst),
        ];
        for (tx, ty) in cases {
            let (w, h) = (tx.width(), tx.height());
            let mut max_err = 0i32;
            for _ in 0..40 {
                let residual = smooth_residual(w, h, &mut rng);
                let coeff = forward_transform_2d(&residual, tx, ty);
                let recon = inverse_transform_2d(&coeff, tx, ty, 8);
                for (r, o) in residual.iter().zip(&recon) {
                    max_err = max_err.max((r - o).abs());
                }
            }
            let bound = 6 + 2 * (tx.log2_width() as i32 + tx.log2_height() as i32);
            assert!(
                max_err <= bound,
                "{tx:?}/{ty:?}: round-trip error {max_err} > {bound}"
            );
        }
    }

    #[test]
    fn zero_block_stays_zero() {
        // Both passes are linear with zero-preserving rounding, so a zero block
        // must come back as exactly zero at every size.
        for tx in ALL_SIZES {
            let area = tx.width() * tx.height();
            let zeros = vec![0i32; area];
            let coeff = forward_transform_2d(&zeros, tx, TxType::DctDct);
            assert_eq!(coeff.len(), area, "{tx:?}");
            assert!(coeff.iter().all(|&c| c == 0), "{tx:?}: forward");
            let recon = inverse_transform_2d(&zeros, tx, TxType::DctDct, 8);
            assert_eq!(recon.len(), area, "{tx:?}");
            assert!(recon.iter().all(|&v| v == 0), "{tx:?}: inverse");
        }
    }

    #[test]
    fn coefficients_outside_top_left_32x32_are_ignored() {
        for tx in [
            TxSize::Tx64x64,
            TxSize::Tx64x32,
            TxSize::Tx32x64,
            TxSize::Tx16x64,
            TxSize::Tx64x16,
        ] {
            let (w, h) = (tx.width(), tx.height());
            let mut dequant = vec![0i32; w * h];
            for i in 0..h {
                for j in 0..w {
                    if i >= 32 || j >= 32 {
                        dequant[i * w + j] = 1000;
                    }
                }
            }
            let out = inverse_transform_2d(&dequant, tx, TxType::DctDct, 8);
            assert!(out.iter().all(|&v| v == 0), "{tx:?}");
        }
    }

    #[test]
    fn flip_flags_match_spec() {
        assert!(TxType::FlipadstDct.flip_ud() && !TxType::FlipadstDct.flip_lr());
        assert!(TxType::DctFlipadst.flip_lr() && !TxType::DctFlipadst.flip_ud());
        assert!(TxType::FlipadstFlipadst.flip_ud() && TxType::FlipadstFlipadst.flip_lr());
        assert!(!TxType::DctDct.flip_ud() && !TxType::DctDct.flip_lr());
    }

    #[test]
    fn flips_cover_every_flipped_type_and_only_adst_axes() {
        assert!(TxType::VFlipadst.flip_ud() && !TxType::VFlipadst.flip_lr());
        assert!(TxType::HFlipadst.flip_lr() && !TxType::HFlipadst.flip_ud());
        assert!(TxType::FlipadstAdst.flip_ud() && !TxType::FlipadstAdst.flip_lr());
        assert!(TxType::AdstFlipadst.flip_lr() && !TxType::AdstFlipadst.flip_ud());
        for ty in ALL_TYPES {
            if ty.flip_ud() {
                assert_eq!(ty.col_1d(), Tx1d::Adst, "{ty:?}: flipped column must be ADST");
            }
            if ty.flip_lr() {
                assert_eq!(ty.row_1d(), Tx1d::Adst, "{ty:?}: flipped row must be ADST");
            }
        }
        let flipped = ALL_TYPES
            .iter()
            .filter(|ty| ty.flip_ud() || ty.flip_lr())
            .count();
        assert_eq!(flipped, 7);
    }

    #[test]
    fn dq_denom_matches_spec() {
        assert_eq!(TxSize::Tx4x4.dq_denom(), 1);
        assert_eq!(TxSize::Tx32x32.dq_denom(), 2);
        assert_eq!(TxSize::Tx64x64.dq_denom(), 4);
        assert_eq!(TxSize::Tx16x32.dq_denom(), 2);
        assert_eq!(TxSize::Tx32x64.dq_denom(), 4);
        assert_eq!(TxSize::Tx8x8.dq_denom(), 1);
    }

    #[test]
    fn dq_denom_follows_block_area_for_every_size() {
        // 1 up to 256 samples, 2 up to 1024, 4 above.
        for tx in ALL_SIZES {
            let area = tx.width() * tx.height();
            let expected: i32 = 1 << (usize::from(area > 256) + usize::from(area > 1024));
            assert_eq!(tx.dq_denom(), expected, "{tx:?}");
        }
    }
}