#![allow(clippy::excessive_precision)]
use alloc::boxed::Box;
use alloc::vec;
use alloc::vec::Vec;
use once_cell::race::OnceBox;
/// Number of transform strategies that have a quantization table in this module
/// (strategy indices `0..NUM_VALID_STRATEGIES`).
pub const NUM_VALID_STRATEGIES: usize = 19;

/// Per-channel reciprocals of [`DC_QUANT`].
pub const INV_DC_QUANT: [f32; 3] = [4096.0, 512.0, 256.0];

/// Per-channel DC quantization factors, defined as the exact reciprocals of
/// [`INV_DC_QUANT`] so the two tables can never drift apart.
#[allow(dead_code)]
pub const DC_QUANT: [f32; 3] = [
    1.0 / INV_DC_QUANT[0],
    1.0 / INV_DC_QUANT[1],
    1.0 / INV_DC_QUANT[2],
];
/// Turns a signed band parameter into the multiplicative step between two
/// consecutive distance bands.
///
/// A positive `v` grows the band by `1 + v`. A zero or negative `v` shrinks it by
/// `1 / (1 - v)`. For finite input the factor is therefore always positive, and
/// it equals exactly 1 at `v == 0`.
fn band_mult(v: f64) -> f64 {
    match v {
        grow if grow > 0.0 => 1.0 + grow,
        shrink => 1.0 / (1.0 - shrink),
    }
}
/// Geometrically interpolates `bands` at the fractional index `pos`.
///
/// The segment index is `i = min(pos as usize, len - 2)`: the integer part of
/// `pos`, with negatives and NaN clamped to 0. With `t = pos - i`, the result is
/// `bands[i] * (bands[i + 1] / bands[i])^t`. Positions beyond the last segment
/// therefore extrapolate along it. A single-entry slice always yields that entry.
///
/// # Panics
/// Panics if `bands` is empty.
fn interpolate_band(pos: f64, bands: &[f64]) -> f64 {
    if let [only] = bands {
        return *only;
    }
    let segment = (pos as usize).min(bands.len() - 2);
    let t = pos - segment as f64;
    let (lo, hi) = (bands[segment], bands[segment + 1]);
    lo * (hi / lo).powf(t)
}
/// Builds the quantization weights of a `rows` x `cols` DCT table for all three channels.
///
/// For channel `c`, `num_bands` distance bands are derived from `band_params[c]`:
/// - band 0 is `band_params[c][0]`;
/// - band `i` is band `i - 1` scaled by `band_mult(band_params[c][i])`.
///
/// Each coefficient `(y, x)` takes the band value interpolated (via
/// `interpolate_band`) at its normalized radial distance from the top-left corner.
/// The stored weight is the reciprocal of that value.
///
/// The result is channel-major with `3 * rows * cols` entries. Entry
/// `c * rows * cols + y * cols + x` belongs to channel `c`, row `y`, column `x`.
///
/// A dimension of length 1 contributes zero distance along its axis. Previously
/// its normalizer was a division by zero, which made every weight of such a table NaN.
///
/// # Panics
/// Panics if `num_bands == 0`, or if any `band_params[c]` has fewer than
/// `num_bands` entries.
fn generate_dct_quant_weights_rect(
    rows: usize,
    cols: usize,
    band_params: &[&[f64]; 3],
    num_bands: usize,
) -> Vec<f32> {
    let num = rows * cols;
    let mut out = vec![0.0f32; 3 * num];
    let sqrt2 = core::f64::consts::SQRT_2;
    // The far corner sits at normalized distance sqrt(2). Dividing by slightly more
    // than that keeps every scaled distance strictly below the last band index.
    let scale = (num_bands as f64 - 1.0) / (sqrt2 + 1e-6);
    // Normalize column/row indices to [0, 1] (times `scale`). The `max(1)` avoids
    // dividing by zero for single-column/single-row tables, whose only index is 0.
    // For dimensions >= 2 this equals the previous `scale / (n as f64 - 1.0)`.
    let rcpcol = scale / cols.saturating_sub(1).max(1) as f64;
    let rcprow = scale / rows.saturating_sub(1).max(1) as f64;
    for (c, params) in band_params.iter().enumerate() {
        let mut bands = vec![0.0f64; num_bands];
        bands[0] = params[0];
        for i in 1..num_bands {
            bands[i] = bands[i - 1] * band_mult(params[i]);
        }
        let channel_out = &mut out[c * num..(c + 1) * num];
        for y in 0..rows {
            let dy = y as f64 * rcprow;
            let dy2 = dy * dy;
            for x in 0..cols {
                let dx = x as f64 * rcpcol;
                let scaled_distance = (dx * dx + dy2).sqrt();
                let dequant_weight = interpolate_band(scaled_distance, &bands);
                channel_out[y * cols + x] = (1.0 / dequant_weight) as f32;
            }
        }
    }
    out
}
// Distance-band parameter tables for the rectangular DCT quantization tables.
// Each table has one row per channel (0, 1, 2; the tests label them X, Y, B).
// Within a row:
// - element 0 is the value of the first (lowest-distance) band;
// - each later element i is passed to `band_mult` to scale band i-1 into band i
//   (see `generate_dct_quant_weights_rect`);
// - the row length is the number of bands.
// NOTE(review): the values appear to be the libjxl default quantization
// parameters; confirm before changing any of them.
const DCT8_PARAMS: [[f64; 6]; 3] = [
    [3150.0, 0.0, -0.4, -0.4, -0.4, -2.0],
    [560.0, 0.0, -0.3, -0.3, -0.3, -0.3],
    [512.0, -2.0, -1.0, 0.0, -1.0, -2.0],
];
// 16x16 table parameters (7 bands per channel).
const DCT16X16_PARAMS: [[f64; 7]; 3] = [
    [
        8996.8725711814115328,
        -1.3000777393353804,
        -0.49424529824571225,
        -0.439093774457103443,
        -0.6350101832695744,
        -0.90177264050827612,
        -1.6162099239887414,
    ],
    [
        3191.48366296844234752,
        -0.67424582104194355,
        -0.80745813428471001,
        -0.44925837484843441,
        -0.35865440981033403,
        -0.31322389111877305,
        -0.37615025315725483,
    ],
    [
        1157.50408145487200256,
        -2.0531423165804414,
        -1.4,
        -0.50687130033378396,
        -0.42708730624733904,
        -1.4856834539296244,
        -4.9209142884401604,
    ],
];
// 16x8 table parameters (7 bands per channel); the table is built with 8 rows x 16 columns.
const DCT16X8_PARAMS: [[f64; 7]; 3] = [
    [7240.7734393502, -0.7, -0.7, -0.2, -0.2, -0.2, -0.5],
    [1448.15468787004, -0.5, -0.5, -0.5, -0.2, -0.2, -0.2],
    [506.854140754517, -1.4, -0.2, -0.5, -0.5, -1.5, -3.6],
];
// 32x32 table parameters (8 bands per channel).
const DCT32X32_BAND_PARAMS: [[f64; 8]; 3] = [
    [
        15718.40830982518931456,
        -1.025,
        -0.98,
        -0.9012,
        -0.4,
        -0.48819395464,
        -0.421064,
        -0.27,
    ],
    [
        7305.7636810695983104,
        -0.8041958212306401,
        -0.7633036457487539,
        -0.55660379990111464,
        -0.49785304658857626,
        -0.43699592683512467,
        -0.40180866526242109,
        -0.27321683125358037,
    ],
    [
        3803.53173721215041536,
        -3.060733579805728,
        -2.0413270132490346,
        -2.0235650159727417,
        -0.5495389509954993,
        -0.4,
        -0.4,
        -0.3,
    ],
];
// 16x32 table parameters (8 bands per channel); the table is built with 16 rows x 32 columns.
const DCT16X32_BAND_PARAMS: [[f64; 8]; 3] = [
    [
        13844.97076442300573,
        -0.97113799999999995,
        -0.658,
        -0.42026,
        -0.22712,
        -0.2206,
        -0.226,
        -0.6,
    ],
    [
        4798.964084220744293,
        -0.61125308982767057,
        -0.83770786552491361,
        -0.79014862079498627,
        -0.2692727459704829,
        -0.38272769465388551,
        -0.22924222653091453,
        -0.20719098826199578,
    ],
    [
        1807.236946760964614,
        -1.2,
        -1.2,
        -0.7,
        -0.7,
        -0.7,
        -0.4,
        -0.5,
    ],
];
// 64x64 table parameters (8 bands per channel).
const DCT64X64_BAND_PARAMS: [[f64; 8]; 3] = [
    [
        23966.16652984486,
        -1.025,
        -0.78,
        -0.65012,
        -0.19041574084286472,
        -0.20819395464,
        -0.421064,
        -0.32733845535848671,
    ],
    [
        8380.191483900904,
        -0.3041958212306401,
        -0.3633036457487539,
        -0.35660379990111464,
        -0.3443074455424403,
        -0.33699592683512467,
        -0.30180866526242109,
        -0.27321683125358037,
    ],
    [
        4493.02378009847706,
        -1.2,
        -1.2,
        -0.8,
        -0.7,
        -0.7,
        -0.4,
        -0.5,
    ],
];
// 32x64 table parameters (8 bands per channel); the table is built with 32 rows x 64 columns.
const DCT32X64_BAND_PARAMS: [[f64; 8]; 3] = [
    [
        15358.898049332399,
        -1.025,
        -0.78,
        -0.65012,
        -0.19041574084286472,
        -0.20819395464,
        -0.421064,
        -0.32733845535848671,
    ],
    [
        5597.36051615065299,
        -0.3041958212306401,
        -0.3633036457487539,
        -0.35660379990111464,
        -0.3443074455424403,
        -0.33699592683512467,
        -0.30180866526242109,
        -0.27321683125358037,
    ],
    [2919.961618960011, -1.2, -1.2, -0.8, -0.7, -0.7, -0.4, -0.5],
];
// Band parameters for the 4-row x 8-column grid used by `generate_dct4x8_weights`
// (same base + `band_mult` scheme, 4 bands per channel).
const DCT4X8_BAND_PARAMS: [[f64; 4]; 3] = [
    [2198.0505, -0.96269625, -0.7619425, -0.65511405],
    [764.36554, -0.926302, -0.967523, -0.2784529],
    [527.10754, -1.4594386, -1.4500821, -1.5843723],
];
// Band parameters for the 4x4 grid used by `generate_dct4x4_weights`
// (same base + `band_mult` scheme, 4 bands per channel).
const DCT4_BAND_PARAMS: [[f64; 4]; 3] = [
    [2200.0, 0.0, 0.0, 0.0],
    [392.0, 0.0, 0.0, 0.0],
    [112.0, -0.25, -0.25, -0.5],
];
// Per-channel divisors applied to the upsampled DCT4X4 values before inversion:
// [0] divides positions 1 and 8, [1] divides position 9. All are 1.0, so they
// currently have no effect.
const DCT4_LLF_PARAMS: [[f64; 2]; 3] = [
    [1.0, 1.0],
    [1.0, 1.0],
    [1.0, 1.0],
];
// Per-channel AFV parameters, as consumed by `generate_afv_weights`:
// - [0], [1]: values for positions 1 and 8;
// - [2], [3], [4]: values for positions 2, 16 and 18;
// - [5]: base of the 4-band interpolation table;
// - [6..9]: `band_mult` steps for bands 1..4.
const AFV_WEIGHTS: [[f64; 9]; 3] = [
    [3072.0, 3072.0, 256.0, 256.0, 256.0, 414.0, 0.0, 0.0, 0.0],
    [1024.0, 1024.0, 50.0, 50.0, 50.0, 58.0, 0.0, 0.0, 0.0],
    [384.0, 384.0, 12.0, 12.0, 12.0, 22.0, -0.25, -0.25, -0.25],
];
// Frequency assigned to each cell (y, x) of the AFV 4x4 grid, at index y * 4 + x.
// The top-left 2x2 cells (indices 0, 1, 4, 5) are zero and never read.
// The smallest used value (index 2) and the largest (index 15) define the LO/HI
// normalization in `generate_afv_weights`.
const AFV_FREQS: [f64; 16] = [
0.0, 0.0, 0.8517778890324296, 5.37778436506804, 0.0, 0.0, 4.734747904497923, 5.449245381693219, 1.6598270267479331, 4.0, 7.275749096817861, 10.423227632456525, 2.662932286148962, 7.630657783650829, 8.962388608184032, 12.97166202570235, ];
static QUANT_WEIGHTS_DCT8: OnceBox<Vec<f32>> = OnceBox::new();
fn quant_weights_dct8() -> &'static [f32] {
QUANT_WEIGHTS_DCT8.get_or_init(|| {
Box::new(generate_dct_quant_weights_rect(
8,
8,
&[&DCT8_PARAMS[0], &DCT8_PARAMS[1], &DCT8_PARAMS[2]],
6,
))
})
}
static QUANT_WEIGHTS_DCT16X16: OnceBox<Vec<f32>> = OnceBox::new();
fn quant_weights_dct16x16() -> &'static [f32] {
QUANT_WEIGHTS_DCT16X16.get_or_init(|| {
Box::new(generate_dct_quant_weights_rect(
16,
16,
&[
&DCT16X16_PARAMS[0],
&DCT16X16_PARAMS[1],
&DCT16X16_PARAMS[2],
],
7,
))
})
}
static QUANT_WEIGHTS_DCT16X8: OnceBox<Vec<f32>> = OnceBox::new();
fn quant_weights_dct16x8() -> &'static [f32] {
QUANT_WEIGHTS_DCT16X8.get_or_init(|| {
Box::new(generate_dct_quant_weights_rect(
8,
16,
&[&DCT16X8_PARAMS[0], &DCT16X8_PARAMS[1], &DCT16X8_PARAMS[2]],
7,
))
})
}
static QUANT_WEIGHTS_DCT32X32: OnceBox<Vec<f32>> = OnceBox::new();
fn quant_weights_dct32x32() -> &'static [f32] {
QUANT_WEIGHTS_DCT32X32.get_or_init(|| {
Box::new(generate_dct_quant_weights_rect(
32,
32,
&[
&DCT32X32_BAND_PARAMS[0],
&DCT32X32_BAND_PARAMS[1],
&DCT32X32_BAND_PARAMS[2],
],
8,
))
})
}
static QUANT_WEIGHTS_DCT16X32: OnceBox<Vec<f32>> = OnceBox::new();
fn quant_weights_dct16x32() -> &'static [f32] {
QUANT_WEIGHTS_DCT16X32.get_or_init(|| {
Box::new(generate_dct_quant_weights_rect(
16,
32,
&[
&DCT16X32_BAND_PARAMS[0],
&DCT16X32_BAND_PARAMS[1],
&DCT16X32_BAND_PARAMS[2],
],
8,
))
})
}
static QUANT_WEIGHTS_DCT64X64: OnceBox<Vec<f32>> = OnceBox::new();
fn quant_weights_dct64x64() -> &'static [f32] {
QUANT_WEIGHTS_DCT64X64.get_or_init(|| {
Box::new(generate_dct_quant_weights_rect(
64,
64,
&[
&DCT64X64_BAND_PARAMS[0],
&DCT64X64_BAND_PARAMS[1],
&DCT64X64_BAND_PARAMS[2],
],
8,
))
})
}
static QUANT_WEIGHTS_DCT32X64: OnceBox<Vec<f32>> = OnceBox::new();
fn quant_weights_dct32x64() -> &'static [f32] {
QUANT_WEIGHTS_DCT32X64.get_or_init(|| {
Box::new(generate_dct_quant_weights_rect(
32,
64,
&[
&DCT32X64_BAND_PARAMS[0],
&DCT32X64_BAND_PARAMS[1],
&DCT32X64_BAND_PARAMS[2],
],
8,
))
})
}
/// Builds the DCT4X8 quantization table: three channels of 64 entries each.
///
/// For each channel:
/// - a 4-row x 8-column grid of dequantization values is interpolated from the
///   distance bands in `DCT4X8_BAND_PARAMS`;
/// - each value is inverted into a quantization weight;
/// - every grid row is emitted twice, so rows `2y` and `2y + 1` of the 8x8
///   output both hold grid row `y`.
///
/// Fix: the band-building loop iterated over `¶ms[1..]`, a garbled `&params[1..]`
/// that does not compile.
fn generate_dct4x8_weights() -> Vec<f32> {
    const WIDTH: usize = 8;
    const HEIGHT: usize = 4;
    let sqrt2 = core::f64::consts::SQRT_2;
    let mut weights = Vec::with_capacity(DCT4X8_BAND_PARAMS.len() * 2 * HEIGHT * WIDTH);
    for params in &DCT4X8_BAND_PARAMS {
        // bands[i] = params[0] * band_mult(params[1]) * ... * band_mult(params[i])
        let mut bands = Vec::with_capacity(params.len());
        let mut last = params[0];
        bands.push(last);
        for &v in &params[1..] {
            last *= band_mult(v);
            bands.push(last);
        }
        let last_band = (bands.len() - 1) as f64;
        let mut row = [0.0f32; WIDTH];
        for y in 0..HEIGHT {
            let dy = y as f64 / (HEIGHT - 1) as f64;
            for (x, slot) in row.iter_mut().enumerate() {
                let dx = x as f64 / (WIDTH - 1) as f64;
                let distance = (dx * dx + dy * dy).sqrt();
                // Same evaluation order as before: (distance * last_band) / (sqrt2 + eps).
                let scaled = distance * last_band / (sqrt2 + 1e-6);
                *slot = (1.0 / interpolate_band(scaled, &bands)) as f32;
            }
            // Each grid row fills two consecutive rows of the 8x8 layout.
            weights.extend_from_slice(&row);
            weights.extend_from_slice(&row);
        }
    }
    weights
}
static QUANT_WEIGHTS_DCT4X8: OnceBox<Vec<f32>> = OnceBox::new();
/// Cached DCT4X8 quantization table, generated on first use.
fn quant_weights_dct4x8() -> &'static [f32] {
    let table: &'static Vec<f32> =
        QUANT_WEIGHTS_DCT4X8.get_or_init(|| Box::from(generate_dct4x8_weights()));
    table.as_slice()
}
static QUANT_WEIGHTS_DCT8X4: OnceBox<Vec<f32>> = OnceBox::new();
/// Cached DCT8X4 quantization table. It is produced by the same generator as the
/// DCT4X8 table, so the contents are identical; it is kept in its own cache.
fn quant_weights_dct8x4() -> &'static [f32] {
    let table: &'static Vec<f32> =
        QUANT_WEIGHTS_DCT8X4.get_or_init(|| Box::from(generate_dct4x8_weights()));
    table.as_slice()
}
/// Builds the DCT4X4 quantization table: three channels of 64 entries each.
///
/// For each channel:
/// - a 4x4 grid of dequantization values is interpolated from the distance bands
///   in `DCT4_BAND_PARAMS`;
/// - the grid is upsampled so that cell `(y, x)` covers the 2x2 patch at rows
///   `2y..2y+2`, columns `2x..2x+2` of the 8x8 layout;
/// - positions 1 and 8 are divided by `DCT4_LLF_PARAMS[c][0]`, and position 9
///   by `DCT4_LLF_PARAMS[c][1]`;
/// - every entry is inverted into a quantization weight.
///
/// Fix: the band-building loop iterated over `¶ms[1..]`, a garbled `&params[1..]`
/// that does not compile.
fn generate_dct4x4_weights() -> Vec<f32> {
    const SIZE: usize = 4;
    let sqrt2 = core::f64::consts::SQRT_2;
    let mut weights = Vec::with_capacity(DCT4_BAND_PARAMS.len() * 64);
    for (params, llf) in DCT4_BAND_PARAMS.iter().zip(&DCT4_LLF_PARAMS) {
        // bands[i] = params[0] * band_mult(params[1]) * ... * band_mult(params[i])
        let mut bands = Vec::with_capacity(params.len());
        let mut last = params[0];
        bands.push(last);
        for &v in &params[1..] {
            last *= band_mult(v);
            bands.push(last);
        }
        let last_band = (bands.len() - 1) as f64;
        let mut channel = [0.0f64; 64];
        for y in 0..SIZE {
            let dy = y as f64 / (SIZE - 1) as f64;
            for x in 0..SIZE {
                let dx = x as f64 / (SIZE - 1) as f64;
                let distance = (dx * dx + dy * dy).sqrt();
                let w = interpolate_band(distance * last_band / (sqrt2 + 1e-6), &bands);
                // Replicate grid cell (y, x) into its 2x2 patch of the 8x8 layout.
                for row in 2 * y..2 * y + 2 {
                    channel[row * 8 + 2 * x] = w;
                    channel[row * 8 + 2 * x + 1] = w;
                }
            }
        }
        channel[1] /= llf[0];
        channel[8] /= llf[0];
        channel[9] /= llf[1];
        weights.extend(channel.iter().map(|&w| (1.0 / w) as f32));
    }
    weights
}
static QUANT_WEIGHTS_DCT4X4: OnceBox<Vec<f32>> = OnceBox::new();
/// Cached DCT4X4 quantization table, generated on first use.
fn quant_weights_dct4x4() -> &'static [f32] {
    let table: &'static Vec<f32> =
        QUANT_WEIGHTS_DCT4X4.get_or_init(|| Box::from(generate_dct4x4_weights()));
    table.as_slice()
}
/// Per-channel dequantization values for the identity transform:
/// `[every other position, positions 1 and 8, position 9]`.
const IDENTITY_WEIGHTS: [[f32; 3]; 3] = [
    [280.0, 3160.0, 3160.0],
    [60.0, 864.0, 864.0],
    [18.0, 200.0, 200.0],
];

/// Builds the identity-transform quantization table: three channels of 64 entries.
///
/// Every entry is the reciprocal of the channel's first `IDENTITY_WEIGHTS`
/// value, except positions 1 and 8 (second value) and position 9 (third value).
fn generate_identity_weights() -> Vec<f32> {
    IDENTITY_WEIGHTS
        .iter()
        .flat_map(|&[base, edge, corner]| {
            let mut channel = [1.0 / base; 64];
            channel[1] = 1.0 / edge;
            channel[8] = 1.0 / edge;
            channel[9] = 1.0 / corner;
            channel
        })
        .collect()
}
static QUANT_WEIGHTS_IDENTITY: OnceBox<Vec<f32>> = OnceBox::new();
/// Cached identity-transform quantization table, generated on first use.
fn quant_weights_identity() -> &'static [f32] {
    let table: &'static Vec<f32> =
        QUANT_WEIGHTS_IDENTITY.get_or_init(|| Box::from(generate_identity_weights()));
    table.as_slice()
}
/// Per-channel DCT2X2 dequantization values for six frequency regions
/// (see `generate_dct2x2_weights` for the position-to-region mapping).
const DCT2_WEIGHTS: [[f32; 6]; 3] = [
    [3840.0, 2560.0, 1280.0, 640.0, 480.0, 300.0],
    [960.0, 640.0, 320.0, 180.0, 140.0, 120.0],
    [640.0, 320.0, 128.0, 64.0, 32.0, 16.0],
];

/// Builds the DCT2X2 quantization table: three channels of 64 entries.
///
/// Position `(y, x)` of each 8x8 channel takes the reciprocal of:
/// - `0xBAD` at (0, 0) — a placeholder value;
/// - `w[0]` at (0, 1) and (1, 0);
/// - `w[1]` at (1, 1);
/// - `w[3]` where both coordinates are in 2..4, and `w[2]` elsewhere inside 4x4;
/// - `w[5]` where both coordinates are in 4..8, and `w[4]` for the rest.
fn generate_dct2x2_weights() -> Vec<f32> {
    let mut weights = Vec::with_capacity(3 * 64);
    for w in &DCT2_WEIGHTS {
        for y in 0..8usize {
            for x in 0..8usize {
                let dequant = match (y, x) {
                    (0, 0) => 0xBAD as f32,
                    (0, 1) | (1, 0) => w[0],
                    (1, 1) => w[1],
                    _ if y < 4 && x < 4 => {
                        if y >= 2 && x >= 2 {
                            w[3]
                        } else {
                            w[2]
                        }
                    }
                    _ if y >= 4 && x >= 4 => w[5],
                    _ => w[4],
                };
                weights.push(1.0 / dequant);
            }
        }
    }
    weights
}
static QUANT_WEIGHTS_DCT2X2: OnceBox<Vec<f32>> = OnceBox::new();
/// Cached DCT2X2 quantization table, generated on first use.
fn quant_weights_dct2x2() -> &'static [f32] {
    let table: &'static Vec<f32> =
        QUANT_WEIGHTS_DCT2X2.get_or_init(|| Box::from(generate_dct2x2_weights()));
    table.as_slice()
}
/// Builds the AFV quantization table: three channels of 64 entries each.
///
/// Every entry is a quantization weight (the reciprocal of a dequantization
/// value). Each 8x8 channel is assembled from four sources:
/// - fixed values from `AFV_WEIGHTS[c][0..5]` at positions 1, 8, 2, 16 and 18;
/// - a 4-entry band table (base `AFV_WEIGHTS[c][5]`, steps from
///   `AFV_WEIGHTS[c][6..9]`). It is interpolated at the `AFV_FREQS` of each
///   even-row / even-column position outside the top-left 2x2 of the 4x4 grid;
/// - the DCT4X8 table in the odd rows, skipping position 8;
/// - the DCT4X4 table in the even rows / odd columns, skipping position 1.
///
/// Position 0 is filled from the first band but is presumably unused — TODO confirm.
fn generate_afv_weights() -> Vec<f32> {
    let mut weights = vec![0.0f32; 192];
    // Full 64-per-channel tables; only selected entries are copied below.
    let weights4x8 = generate_dct4x8_weights();
    let weights4x4 = generate_dct4x4_weights();
    // Normalization range mapping the used AFV_FREQS onto band indices [0, 3).
    const LO: f64 = 0.8517778890324296;
    const HI: f64 = 12.97166202570235 - LO + 1e-6;
    for (c, afv) in AFV_WEIGHTS.iter().enumerate() {
        let start = c * 64;
        // bands[0] = afv[5]; bands[i] = bands[i - 1] * band_mult(afv[5 + i]) for i in 1..4.
        let mut bands = [0.0f64; 4];
        bands[0] = afv[5]; for i in 1..4 {
            bands[i] = bands[i - 1] * band_mult(afv[5 + i]);
        }
        weights[start] = (1.0 / bands[0]) as f32;
        // NOTE(review): index 1 is row 0 / column 1 and index 8 is row 1 / column 0
        // (likewise 2 vs 16). libjxl's AFV setup appears to put afv[0]/afv[2] at the
        // row-1/row-2 positions and afv[1]/afv[3] at the column-1/column-2 positions,
        // i.e. the transpose of what is done here — confirm. With the current
        // AFV_WEIGHTS (afv[0] == afv[1], afv[2] == afv[3]) the output is the same either way.
        weights[start + 1] = (1.0 / afv[0]) as f32; weights[start + 8] = (1.0 / afv[1]) as f32;
        weights[start + 2] = (1.0 / afv[2]) as f32; weights[start + 16] = (1.0 / afv[3]) as f32; weights[start + 18] = (1.0 / afv[4]) as f32;
        // Interpolated weights: grid cell (y, x) -> position (2y, 2x). The top-left
        // 2x2 of the grid is skipped; those positions hold the fixed values above.
        for y in 0..4usize {
            for x in 0..4usize {
                if x < 2 && y < 2 {
                    continue; }
                let freq = AFV_FREQS[y * 4 + x];
                let val = interpolate_band((freq - LO) / HI * 3.0, &bands);
                weights[start + (2 * y) * 8 + (2 * x)] = (1.0 / val) as f32;
            }
        }
        // Odd rows: DCT4X8 grid row y (its first copy at c*64 + y*16) -> row 2y + 1.
        // Position 8 (row 1, column 0) keeps its fixed value.
        for y in 0..4usize {
            for x in 0..8usize {
                if x == 0 && y == 0 {
                    continue; }
                let idx4x8 = c * 64 + y * 16 + x; weights[start + (2 * y + 1) * 8 + x] = weights4x8[idx4x8];
            }
        }
        // Even rows / odd columns: DCT4X4 grid cell (y, x), read at its upsampled
        // position (2y, 2x), -> position (2y, 2x + 1). Position 1 keeps its fixed value.
        for y in 0..4usize {
            for x in 0..4usize {
                if x == 0 && y == 0 {
                    continue; }
                let idx4x4 = c * 64 + y * 16 + x * 2; weights[start + (2 * y) * 8 + (2 * x + 1)] = weights4x4[idx4x4];
            }
        }
    }
    weights
}
static QUANT_WEIGHTS_AFV: OnceBox<Vec<f32>> = OnceBox::new();
/// Cached AFV quantization table, generated on first use.
fn quant_weights_afv() -> &'static [f32] {
    let table: &'static Vec<f32> =
        QUANT_WEIGHTS_AFV.get_or_init(|| Box::from(generate_afv_weights()));
    table.as_slice()
}
/// Number of quantization weights per channel for each strategy index, in the
/// same order as the dispatch in `quant_weights_full`.
pub(super) const WEIGHT_SIZES: [usize; NUM_VALID_STRATEGIES] = [
    64,   // 0: 8x8
    128,  // 1: 16x8 table
    128,  // 2: 16x8 table
    256,  // 3: 16x16
    1024, // 4: 32x32
    64,   // 5: DCT4X8
    64,   // 6: DCT8X4
    64,   // 7: DCT4X4
    64,   // 8: identity
    64,   // 9: DCT2X2
    512,  // 10: 16x32 table
    512,  // 11: 16x32 table
    64,   // 12: AFV
    64,   // 13: AFV
    64,   // 14: AFV
    64,   // 15: AFV
    4096, // 16: 64x64
    2048, // 17: 32x64 table
    2048, // 18: 32x64 table
];
/// Returns the cached quantization table for `strategy`, with all three
/// channels stored back to back (`WEIGHT_SIZES[strategy]` entries each).
///
/// Some strategies share one table: 1|2, 10|11, 12..=15 and 17|18.
///
/// # Panics
/// Panics if `strategy >= NUM_VALID_STRATEGIES`.
#[inline]
pub(super) fn quant_weights_full(strategy: usize) -> &'static [f32] {
    match strategy {
        // Square DCT tables.
        0 => quant_weights_dct8(),
        3 => quant_weights_dct16x16(),
        4 => quant_weights_dct32x32(),
        16 => quant_weights_dct64x64(),
        // Rectangular DCT tables, each shared by a pair of strategies.
        1 | 2 => quant_weights_dct16x8(),
        10 | 11 => quant_weights_dct16x32(),
        17 | 18 => quant_weights_dct32x64(),
        // 8x8-sized tables for the small transforms.
        5 => quant_weights_dct4x8(),
        6 => quant_weights_dct8x4(),
        7 => quant_weights_dct4x4(),
        8 => quant_weights_identity(),
        9 => quant_weights_dct2x2(),
        12..=15 => quant_weights_afv(),
        _ => unreachable!("Invalid strategy: {}", strategy),
    }
}
/// Returns the quantization weights of one channel (`0..3`) for `strategy`.
///
/// The slice is the channel's `WEIGHT_SIZES[strategy]`-entry segment of
/// `quant_weights_full(strategy)`.
///
/// # Panics
/// Panics if `strategy` or `channel` is out of range. Debug builds check this
/// with `debug_assert!`; release builds panic on indexing.
#[inline]
pub fn quant_weights(strategy: usize, channel: usize) -> &'static [f32] {
    debug_assert!(strategy < NUM_VALID_STRATEGIES);
    debug_assert!(channel < 3);
    let len = WEIGHT_SIZES[strategy];
    let begin = len * channel;
    let end = begin + len;
    let table = quant_weights_full(strategy);
    &table[begin..end]
}
/// Returns the reciprocal of a single quantization weight, i.e. the value a
/// quantized coefficient at `coeff_idx` is multiplied by when dequantizing.
#[inline]
#[allow(dead_code)]
pub fn inv_quant_weight(strategy: usize, channel: usize, coeff_idx: usize) -> f32 {
    let channel_weights = quant_weights(strategy, channel);
    debug_assert!(coeff_idx < channel_weights.len());
    let step = channel_weights[coeff_idx];
    1.0 / step
}
/// Builds the three-channel dequantization table for `strategy` by taking the
/// reciprocal of every quantization weight, channel by channel.
fn generate_dequant_weights(strategy: usize) -> Vec<f32> {
    let mut table = Vec::with_capacity(3 * WEIGHT_SIZES[strategy]);
    table.extend((0..3).flat_map(|c| quant_weights(strategy, c).iter().map(|&q| 1.0 / q)));
    table
}
static DEQUANT_WEIGHTS_DCT8: OnceBox<Vec<f32>> = OnceBox::new();
static DEQUANT_WEIGHTS_DCT16X8: OnceBox<Vec<f32>> = OnceBox::new();
static DEQUANT_WEIGHTS_DCT16X16: OnceBox<Vec<f32>> = OnceBox::new();
static DEQUANT_WEIGHTS_DCT32X32: OnceBox<Vec<f32>> = OnceBox::new();
static DEQUANT_WEIGHTS_DCT4X8: OnceBox<Vec<f32>> = OnceBox::new();
static DEQUANT_WEIGHTS_DCT8X4: OnceBox<Vec<f32>> = OnceBox::new();
static DEQUANT_WEIGHTS_DCT4X4: OnceBox<Vec<f32>> = OnceBox::new();
static DEQUANT_WEIGHTS_IDENTITY: OnceBox<Vec<f32>> = OnceBox::new();
static DEQUANT_WEIGHTS_DCT2X2: OnceBox<Vec<f32>> = OnceBox::new();
static DEQUANT_WEIGHTS_DCT16X32: OnceBox<Vec<f32>> = OnceBox::new();
static DEQUANT_WEIGHTS_AFV: OnceBox<Vec<f32>> = OnceBox::new();
static DEQUANT_WEIGHTS_DCT64X64: OnceBox<Vec<f32>> = OnceBox::new();
static DEQUANT_WEIGHTS_DCT32X64: OnceBox<Vec<f32>> = OnceBox::new();
#[inline]
pub(super) fn dequant_weights_full(strategy: usize) -> &'static [f32] {
match strategy {
0 => DEQUANT_WEIGHTS_DCT8.get_or_init(|| Box::new(generate_dequant_weights(0))),
1 | 2 => DEQUANT_WEIGHTS_DCT16X8.get_or_init(|| Box::new(generate_dequant_weights(1))),
3 => DEQUANT_WEIGHTS_DCT16X16.get_or_init(|| Box::new(generate_dequant_weights(3))),
4 => DEQUANT_WEIGHTS_DCT32X32.get_or_init(|| Box::new(generate_dequant_weights(4))),
5 => DEQUANT_WEIGHTS_DCT4X8.get_or_init(|| Box::new(generate_dequant_weights(5))),
6 => DEQUANT_WEIGHTS_DCT8X4.get_or_init(|| Box::new(generate_dequant_weights(6))),
7 => DEQUANT_WEIGHTS_DCT4X4.get_or_init(|| Box::new(generate_dequant_weights(7))),
8 => DEQUANT_WEIGHTS_IDENTITY.get_or_init(|| Box::new(generate_dequant_weights(8))),
9 => DEQUANT_WEIGHTS_DCT2X2.get_or_init(|| Box::new(generate_dequant_weights(9))),
10 | 11 => DEQUANT_WEIGHTS_DCT16X32.get_or_init(|| Box::new(generate_dequant_weights(10))),
12..=15 => DEQUANT_WEIGHTS_AFV.get_or_init(|| Box::new(generate_dequant_weights(12))),
16 => DEQUANT_WEIGHTS_DCT64X64.get_or_init(|| Box::new(generate_dequant_weights(16))),
17 | 18 => DEQUANT_WEIGHTS_DCT32X64.get_or_init(|| Box::new(generate_dequant_weights(17))),
_ => unreachable!("Invalid strategy: {}", strategy),
}
}
/// Returns the dequantization values of one channel (`0..3`) for `strategy`.
///
/// The slice has `WEIGHT_SIZES[strategy]` entries, each the reciprocal of the
/// matching entry of `quant_weights(strategy, channel)`.
///
/// # Panics
/// Panics if `strategy` or `channel` is out of range. Debug builds check this
/// with `debug_assert!`; release builds panic on indexing.
#[inline]
pub fn dequant_weights(strategy: usize, channel: usize) -> &'static [f32] {
    debug_assert!(strategy < NUM_VALID_STRATEGIES);
    debug_assert!(channel < 3);
    let len = WEIGHT_SIZES[strategy];
    let begin = len * channel;
    let table = dequant_weights_full(strategy);
    &table[begin..begin + len]
}
/// Quantizes one coefficient.
///
/// Computes `coeff * global_scale` divided by the coefficient's quantization
/// weight, rounded to the nearest integer (`f32::round`, ties away from zero).
/// The final `as i32` cast saturates.
#[inline]
#[allow(dead_code)]
pub fn quantize_coeff(
    coeff: f32,
    strategy: usize,
    channel: usize,
    coeff_idx: usize,
    global_scale: f32,
) -> i32 {
    let step = quant_weights(strategy, channel)[coeff_idx];
    let scaled = coeff * global_scale;
    let quantized = (scaled / step).round();
    quantized as i32
}
/// Reverses `quantize_coeff`: multiplies the integer coefficient by its
/// quantization weight, then by `inv_global_scale`.
#[inline]
#[allow(dead_code)]
pub fn dequantize_coeff(
    qcoeff: i32,
    strategy: usize,
    channel: usize,
    coeff_idx: usize,
    inv_global_scale: f32,
) -> f32 {
    let step = quant_weights(strategy, channel)[coeff_idx];
    let unscaled = qcoeff as f32 * step;
    unscaled * inv_global_scale
}
// Unit tests for the quantization tables. Several tests are diagnostic dumps
// that print values with `eprintln!` and assert nothing.
#[cfg(test)]
mod tests {
    use super::*;
    // Generated tables hold three channels of rows x cols entries.
    #[test]
    fn test_table_sizes() {
        assert_eq!(quant_weights_dct8().len(), 192);
        assert_eq!(quant_weights_dct16x8().len(), 384);
        assert_eq!(quant_weights_dct16x16().len(), 768);
        assert_eq!(quant_weights_dct32x32().len(), 3072);
    }
    // DC_QUANT and INV_DC_QUANT must be reciprocals channel by channel.
    #[test]
    fn test_dc_quant_inverse() {
        for c in 0..3 {
            let product = DC_QUANT[c] * INV_DC_QUANT[c];
            assert!(
                (product - 1.0).abs() < 1e-6,
                "DC_QUANT[{}] * INV_DC_QUANT[{}] = {} != 1.0",
                c,
                c,
                product
            );
        }
    }
    // Per-channel slices have the expected lengths for strategies 0..8.
    #[test]
    fn test_quant_weights_access() {
        for c in 0..3 {
            assert_eq!(quant_weights(0, c).len(), 64);
        }
        for strategy in 1..3 {
            for c in 0..3 {
                assert_eq!(quant_weights(strategy, c).len(), 128);
            }
        }
        for c in 0..3 {
            assert_eq!(quant_weights(3, c).len(), 256);
        }
        for c in 0..3 {
            assert_eq!(quant_weights(4, c).len(), 1024);
        }
        for strategy in 5..8 {
            for c in 0..3 {
                assert_eq!(quant_weights(strategy, c).len(), 64);
            }
        }
    }
    // Every weight of strategies 0..8 is strictly positive and the slices have the right length.
    #[test]
    fn test_all_weights_positive() {
        let strategies = [
            (0, "DCT8", 64),
            (1, "DCT16X8", 128),
            (2, "DCT8X16", 128),
            (3, "DCT16X16", 256),
            (4, "DCT32X32", 1024),
            (5, "DCT4X8", 64),
            (6, "DCT8X4", 64),
            (7, "DCT4X4", 64),
        ];
        for &(strat, name, expected_len) in &strategies {
            for c in 0..3 {
                let w = quant_weights(strat, c);
                assert_eq!(w.len(), expected_len, "{} ch={} wrong length", name, c);
                for (i, &val) in w.iter().enumerate() {
                    assert!(
                        val > 0.0,
                        "{} weight[ch={}, {}] = {} should be positive",
                        name,
                        c,
                        i,
                        val
                    );
                }
            }
        }
    }
    // For square DCT tables, the DC entry has (within 1%) the smallest
    // quantization weight, i.e. the largest dequantization value.
    #[test]
    fn test_dc_smallest_weight() {
        let strategies = [(0, "DCT8"), (3, "DCT16X16"), (4, "DCT32X32")];
        for &(strat, name) in &strategies {
            for c in 0..3 {
                let w = quant_weights(strat, c);
                let dc = w[0];
                for (i, &val) in w.iter().enumerate().skip(1) {
                    assert!(
                        val >= dc * 0.99, "{} weight[ch={}, {}] = {} is less than DC = {}",
                        name,
                        c,
                        i,
                        val,
                        dc
                    );
                }
            }
        }
    }
    // Quantize then dequantize a DCT8 DC coefficient at unit scale; the
    // reconstruction error must stay within half a quantization step.
    #[test]
    fn test_quantize_dequantize_roundtrip() {
        let global_scale = 1.0;
        let inv_scale = 1.0;
        let test_values = [1.0f32, -1.0, 100.0, -100.0, 0.001, -0.001];
        for &val in &test_values {
            let q = quantize_coeff(val, 0, 0, 0, global_scale);
            let dq = dequantize_coeff(q, 0, 0, 0, inv_scale);
            let weight = quant_weights(0, 0)[0];
            let expected_error = weight / 2.0; assert!(
                (dq - val).abs() <= expected_error + 1e-6,
                "Roundtrip error too large: {} -> {} -> {}, weight={}",
                val,
                q,
                dq,
                weight
            );
        }
    }
    // All strategies and channels keep weights within (1e-7, 1.0).
    #[test]
    fn test_weight_ranges() {
        for strat in 0..NUM_VALID_STRATEGIES {
            for c in 0..3 {
                let w = quant_weights(strat, c);
                let min_weight = w.iter().cloned().fold(f32::INFINITY, f32::min);
                let max_weight = w.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
                assert!(
                    min_weight > 1e-7,
                    "strat={} ch={}: min weight {} too small",
                    strat,
                    c,
                    min_weight
                );
                assert!(
                    max_weight < 1.0,
                    "strat={} ch={}: max weight {} too large",
                    strat,
                    c,
                    max_weight
                );
            }
        }
    }
    // Prints DCT4X8 vs DCT8 weights at positions 0 and 8, then checks that
    // DCT4X8[8] stays within a factor of 10 of the DCT8 DC weight.
    #[test]
    fn test_dct4x8_position8_weight() {
        let channels = ["X", "Y", "B"];
        for (c, ch_name) in channels.iter().enumerate() {
            let w = quant_weights(5, c); eprintln!(
                "DCT4X8 {}: pos0={:.6}, pos8={:.6}, ratio={:.6}",
                ch_name,
                w[0],
                w[8],
                w[0] / w[8]
            );
            let dct8_w = quant_weights(0, c);
            eprintln!(
                " DCT8 {}: pos0={:.6}, pos8={:.6}",
                ch_name, dct8_w[0], dct8_w[8]
            );
        }
        for c in 0..3 {
            let dct4x8 = quant_weights(5, c);
            let dct8 = quant_weights(0, c);
            let ratio = dct4x8[8] / dct8[0];
            assert!(
                (0.1..10.0).contains(&ratio),
                "DCT4X8[8] / DCT8[0] ratio for channel {} is out of range: {} ({}:{})",
                c,
                ratio,
                dct4x8[8],
                dct8[0]
            );
        }
    }
    // Diagnostic only: compares DCT8 (y, x) with DCT16 (2y, 2x) dequantization
    // values per channel, then prints average dequantization over the DCT16
    // positions that have an odd row or column versus the DCT8 average (Y channel).
    // Asserts nothing.
    #[test]
    fn test_dct16_vs_dct8_equivalent_frequencies() {
        let channels = ["X", "Y", "B"];
        for (c, ch_name) in channels.iter().enumerate() {
            let w8 = quant_weights(0, c);
            let w16 = quant_weights(3, c);
            eprintln!(
                "=== Channel {} equivalent frequency dequant weights ===",
                ch_name
            );
            eprintln!(
                "{:>10} {:>12} {:>12} {:>8}",
                "DCT8(y,x)", "DCT8_dequant", "DCT16_dequant", "ratio"
            );
            for y8 in 0..8 {
                for x8 in 0..8 {
                    let idx8 = y8 * 8 + x8;
                    let y16 = y8 * 2;
                    let x16 = x8 * 2;
                    let idx16 = y16 * 16 + x16;
                    let dequant8 = 1.0 / w8[idx8];
                    let dequant16 = 1.0 / w16[idx16];
                    let ratio = dequant16 / dequant8;
                    // Only the low-frequency 4x4 corner is printed.
                    if y8 < 4 && x8 < 4 {
                        eprintln!(
                            " ({},{}) {:>12.2} {:>12.2} {:>8.3}",
                            y8, x8, dequant8, dequant16, ratio
                        );
                    }
                }
            }
            let mut total_ratio = 0.0f64;
            for y8 in 0..8 {
                for x8 in 0..8 {
                    let dequant8 = 1.0 / w8[y8 * 8 + x8] as f64;
                    let dequant16 = 1.0 / w16[y8 * 2 * 16 + x8 * 2] as f64;
                    total_ratio += dequant16 / dequant8;
                }
            }
            eprintln!(
                " Average dequant16/dequant8 ratio: {:.4}",
                total_ratio / 64.0
            );
        }
        eprintln!("\n=== DCT16 extra frequencies (Y channel) ===");
        let w16 = quant_weights(3, 1);
        let mut extra_dequant_sum = 0.0f64;
        let mut extra_count = 0;
        for y in 0..16 {
            for x in 0..16 {
                // Positions with an odd row or column have no DCT8 counterpart.
                if y % 2 != 0 || x % 2 != 0 {
                    let dequant = 1.0 / w16[y * 16 + x] as f64;
                    extra_dequant_sum += dequant;
                    extra_count += 1;
                }
            }
        }
        let w8 = quant_weights(0, 1);
        let mut base_dequant_sum = 0.0f64;
        for y in 0..8 {
            for x in 0..8 {
                base_dequant_sum += 1.0 / w8[y * 8 + x] as f64;
            }
        }
        eprintln!(
            "DCT8 avg dequant: {:.2}, DCT16 extra freq avg dequant: {:.2}",
            base_dequant_sum / 64.0,
            extra_dequant_sum / extra_count as f64
        );
    }
    // Diagnostic only: prints min/max/mean weights and the reciprocal range per
    // strategy (0..7) and channel. Asserts nothing.
    #[test]
    fn test_weight_stats_per_strategy() {
        let strategies = [
            (0, "DCT8"),
            (1, "DCT16x8"),
            (2, "DCT8x16"),
            (3, "DCT16x16"),
            (4, "DCT32x32"),
            (5, "DCT4X8"),
            (6, "DCT8X4"),
        ];
        let channels = ["X", "Y", "B"];
        for &(strat, name) in &strategies {
            for (c, ch_name) in channels.iter().enumerate() {
                let w = quant_weights(strat, c);
                let min_w = w.iter().cloned().fold(f32::INFINITY, f32::min);
                let max_w = w.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
                let mean_w: f32 = w.iter().sum::<f32>() / w.len() as f32;
                let max_inv = 1.0 / min_w;
                let min_inv = 1.0 / max_w;
                eprintln!(
                    "{:>8} ch={}: {} coeffs, weight range [{:.6}, {:.6}], mean={:.6}, inv range [{:.1}, {:.1}]",
                    name,
                    ch_name,
                    w.len(),
                    min_w,
                    max_w,
                    mean_w,
                    min_inv,
                    max_inv
                );
            }
        }
    }
}
#[cfg(test)]
mod weight_debug_tests {
    use super::*;

    /// Diagnostic dump: prints the first eight DCT8 quantization weights (and
    /// their reciprocals) for channels 0 and 1. Asserts nothing.
    #[test]
    fn test_print_dct8_weights() {
        let dump = |channel: usize, header: &str| {
            println!("{}", header);
            let weights = quant_weights(0, channel);
            for (i, &wi) in weights.iter().enumerate().take(8) {
                println!(" [{}] = {:.6e} (reciprocal = {:.6e})", i, wi, 1.0 / wi);
            }
        };
        dump(0, "DCT8 X channel quant_weights()[0..8]:");
        dump(1, "DCT8 Y channel quant_weights()[0..8]:");
    }
}