use super::EncoderError;
/// Forward-quantization multipliers, indexed as `MF[qp % 6][class]`.
///
/// `class` is the coefficient position class returned by `norm_class`:
/// column 0 for positions whose row and column are both even, column 1 for
/// both odd, column 2 for mixed parity. The quantizers in this module compute
/// `(|c| * MF + f) >> qbits`, so a larger entry yields a larger level.
/// NOTE(review): these appear to be the standard H.264 forward-quantization
/// multipliers (as tabulated for the JM reference encoder); keep them in sync
/// with the dequantization tables if they are ever changed.
const MF: [[i32; 3]; 6] = [
[13107, 5243, 8066],
[11916, 4660, 7490],
[10082, 4194, 6554],
[9362, 3647, 5825],
[8192, 3355, 5243],
[7282, 2893, 4559],
];
/// Slice type, which selects the quantizer's rounding (dead-zone) offset.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QuantSlice {
/// Intra slice: rounding offset `2^qbits / 3`.
Intra,
/// Inter slice: rounding offset `2^qbits / 6` (more values round to zero).
Inter,
}
/// Per-block quantizer settings for `forward_quantize_4x4` and
/// `trellis_quantize_4x4`.
#[derive(Debug, Clone, Copy)]
pub struct QuantParams {
/// Quantization parameter; expected in `0..=51` (checked only by `debug_assert!`).
pub qp: u8,
/// Slice type, selecting the rounding offset.
pub slice: QuantSlice,
}
/// Position class of coefficient `(i, j)` in a 4x4 block, used as the column
/// index into `MF`:
/// * `0` — row and column both even,
/// * `1` — row and column both odd,
/// * `2` — mixed parity.
#[inline]
const fn norm_class(i: usize, j: usize) -> usize {
    match (i % 2, j % 2) {
        (0, 0) => 0,
        (1, 1) => 1,
        _ => 2,
    }
}
/// Quantizer right-shift used for `qp` values 0..=5.
const Q_BITS_BASE: u32 = 15;

/// Right-shift applied by the forward quantizer at `qp`.
///
/// The shift grows by one for every six `qp` steps; each extra bit of shift
/// halves the resulting level.
#[inline]
const fn q_bits(qp: u8) -> u32 {
    let octave = (qp / 6) as u32;
    octave + Q_BITS_BASE
}
/// Rounding offset added before the `qbits` right-shift.
///
/// Returns `2^qbits / 3` for intra slices and `2^qbits / 6` for inter
/// slices. The smaller inter offset rounds more coefficients down to zero.
#[inline]
const fn f_offset(slice: QuantSlice, qbits: u32) -> i64 {
    let divisor: i64 = match slice {
        QuantSlice::Intra => 3,
        QuantSlice::Inter => 6,
    };
    (1i64 << qbits) / divisor
}
/// Scalar dead-zone quantization of a 4x4 block of transform coefficients.
///
/// Each level is `sign(c) * ((|c| * MF + f) >> qbits)`, where:
/// * `MF` is taken from row `params.qp % 6` of `MF`, at the column given by
///   the coefficient's `norm_class`;
/// * `qbits = q_bits(params.qp)`;
/// * `f = f_offset(params.slice, qbits)`.
///
/// # Panics
///
/// In debug builds, if `params.qp > 51`.
pub fn forward_quantize_4x4(coeffs: &[[i32; 4]; 4], params: QuantParams) -> [[i32; 4]; 4] {
    debug_assert!(params.qp <= 51, "qp out of range: {}", params.qp);
    let shift = q_bits(params.qp);
    let rounding = f_offset(params.slice, shift);
    let scales = &MF[usize::from(params.qp % 6)];
    std::array::from_fn(|row| {
        std::array::from_fn(|col| {
            let coeff = coeffs[row][col];
            let scale = i64::from(scales[norm_class(row, col)]);
            // Quantize the magnitude, then reapply the sign.
            let magnitude =
                ((i64::from(coeff.unsigned_abs()) * scale + rounding) >> shift) as i32;
            if coeff < 0 {
                -magnitude
            } else {
                magnitude
            }
        })
    })
}
/// Quantizes a 4x4 block of luma DC coefficients.
///
/// Every entry uses the position-class-0 multiplier `MF[qp % 6][0]`. The
/// right-shift is `q_bits(qp) + 2`, and the rounding offset is computed for
/// that enlarged shift. The sign of each input is preserved.
///
/// # Panics
///
/// In debug builds, if `qp > 51`.
pub fn forward_quantize_dc_luma(
    dc_hadamard: &[[i32; 4]; 4],
    qp: u8,
    slice: QuantSlice,
) -> [[i32; 4]; 4] {
    debug_assert!(qp <= 51, "qp out of range: {qp}");
    let shift = q_bits(qp) + 2;
    let rounding = f_offset(slice, shift);
    let dc_scale = i64::from(MF[usize::from(qp % 6)][0]);
    let quantize = |coeff: i32| -> i32 {
        let magnitude =
            ((i64::from(coeff.unsigned_abs()) * dc_scale + rounding) >> shift) as i32;
        if coeff < 0 {
            -magnitude
        } else {
            magnitude
        }
    };
    dc_hadamard.map(|row| row.map(&quantize))
}
/// Quantizes a 2x2 block of chroma DC coefficients.
///
/// Every entry uses the position-class-0 multiplier `MF[qp_c % 6][0]`. The
/// right-shift is `q_bits(qp_c) + 1`, and the rounding offset is computed for
/// that enlarged shift. The sign of each input is preserved.
///
/// # Panics
///
/// In debug builds, if `qp_c > 51`.
pub fn forward_quantize_dc_chroma(
    dc_hadamard: &[[i32; 2]; 2],
    qp_c: u8,
    slice: QuantSlice,
) -> [[i32; 2]; 2] {
    debug_assert!(qp_c <= 51, "chroma qp out of range: {qp_c}");
    let shift = q_bits(qp_c) + 1;
    let rounding = f_offset(slice, shift);
    let dc_scale = i64::from(MF[usize::from(qp_c % 6)][0]);
    let quantize = |coeff: i32| -> i32 {
        let magnitude =
            ((i64::from(coeff.unsigned_abs()) * dc_scale + rounding) >> shift) as i32;
        if coeff < 0 {
            -magnitude
        } else {
            magnitude
        }
    };
    dc_hadamard.map(|row| row.map(&quantize))
}
/// Rate-distortion pruning of a 4x4 block on top of [`forward_quantize_4x4`].
///
/// When `enable` is `false`, this returns exactly the scalar levels from
/// [`forward_quantize_4x4`]. Otherwise, each nonzero level is visited in
/// reverse zigzag order and set to zero when the lambda-weighted saving of a
/// flat `BITS_PER_COEFF` bits is at least the estimated distortion increase
/// of dropping it. Levels are only ever zeroed, never changed to another
/// value, and each decision is independent of the others.
///
/// The lambda from `super::rdo::lambda2_for_qp` (treated here as Q8 fixed
/// point) is scaled by the `PHASM_TRELLIS_LAMBDA_MULT` environment variable.
/// That variable is a Q10 multiplier, where `1024` means 1.0. Missing or
/// unparseable values use `1024`, and negative values are clamped to `0`.
/// The intermediate products saturate rather than overflow, so very large
/// multipliers cannot panic or wrap.
///
/// # Errors
///
/// Currently never returns `Err`; the `Result` is part of the public signature.
pub fn trellis_quantize_4x4(
    coeffs: &[[i32; 4]; 4],
    params: QuantParams,
    enable: bool,
) -> Result<[[i32; 4]; 4], EncoderError> {
    /// Scan positions `(row, col)` in zigzag order; iterated in reverse below.
    const ZIGZAG_POSITIONS: [(usize, usize); 16] = [
        (0, 0), (0, 1), (1, 0), (2, 0), (1, 1), (0, 2), (0, 3), (1, 2),
        (2, 1), (3, 0), (3, 1), (2, 2), (1, 3), (2, 3), (3, 2), (3, 3),
    ];
    /// Flat estimate of the bits saved by zeroing one level.
    const BITS_PER_COEFF: i64 = 4;
    /// Q10 lambda multiplier used when the environment override is absent (1.0).
    const DEFAULT_LAMBDA_MULT_Q10: i64 = 1024;

    let mut levels = forward_quantize_4x4(coeffs, params);
    if !enable {
        return Ok(levels);
    }

    let qbits = q_bits(params.qp);
    let mf_row = &MF[usize::from(params.qp % 6)];
    let lambda2_q8: i64 = super::rdo::lambda2_for_qp(params.qp) as i64;
    // NOTE(review): the environment is read on every call. In std, this locks
    // the environment and allocates a String. Consider caching the value once
    // tests no longer need to change it at runtime.
    let mult_q10: i64 = std::env::var("PHASM_TRELLIS_LAMBDA_MULT")
        .ok()
        .and_then(|s| s.trim().parse::<i64>().ok())
        .unwrap_or(DEFAULT_LAMBDA_MULT_Q10)
        .max(0);
    let lambda2_scaled = lambda2_q8.saturating_mul(mult_q10) >> 10;
    // Loop-invariant: the flat rate model charges the same bits for every level.
    let rate_gain = BITS_PER_COEFF.saturating_mul(lambda2_scaled) >> 8;

    for &(i, j) in ZIGZAG_POSITIONS.iter().rev() {
        let level = levels[i][j];
        if level == 0 {
            continue;
        }
        let c = i64::from(coeffs[i][j]);
        let mf = i64::from(mf_row[norm_class(i, j)]);
        // Approximate reconstruction by inverting `level ≈ |c| * mf >> qbits`.
        let c_hat_mag = i64::from(level.unsigned_abs()) * (1i64 << qbits) / mf;
        let c_hat = if level < 0 { -c_hat_mag } else { c_hat_mag };
        let d_keep = c - c_hat;
        // Squared error when dropped (c^2) minus squared error when kept.
        let dist_increase_coef = c * c - d_keep * d_keep;
        // NOTE(review): `>> 4` rescales the transform-domain error toward the
        // pixel domain. Confirm the constant against the transform's gain.
        let dist_increase_px = dist_increase_coef >> 4;
        if rate_gain >= dist_increase_px {
            levels[i][j] = 0;
        }
    }
    Ok(levels)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codec::h264::transform::{dequant_4x4, inverse_4x4_integer};
    use std::sync::{Mutex, PoisonError};

    /// Environment variable read by `trellis_quantize_4x4` to scale lambda.
    const LAMBDA_MULT_VAR: &str = "PHASM_TRELLIS_LAMBDA_MULT";

    /// Serializes every test in this module that reads or writes
    /// `PHASM_TRELLIS_LAMBDA_MULT`. The test harness runs tests on parallel
    /// threads; without this lock, one test's override could leak into another
    /// test's `trellis_quantize_4x4` call.
    static LAMBDA_MULT_ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Runs `body` with the lambda multiplier set to `value` (`None` = unset,
    /// i.e. the default multiplier).
    ///
    /// Holds `LAMBDA_MULT_ENV_LOCK` for the whole call and leaves the variable
    /// unset afterwards, even if `body` panics.
    fn with_lambda_mult<R>(value: Option<&str>, body: impl FnOnce() -> R) -> R {
        struct UnsetOnDrop;
        impl Drop for UnsetOnDrop {
            fn drop(&mut self) {
                // SAFETY: this guard is declared after the lock guard in
                // `with_lambda_mult`, so it drops first, while
                // `LAMBDA_MULT_ENV_LOCK` is still held.
                unsafe { std::env::remove_var(LAMBDA_MULT_VAR) };
            }
        }

        // A panicking test poisons the lock. The protected state (the
        // variable) has already been reset by `UnsetOnDrop`, so recover.
        let _lock = LAMBDA_MULT_ENV_LOCK
            .lock()
            .unwrap_or_else(PoisonError::into_inner);
        let _unset = UnsetOnDrop;
        // SAFETY: every test in this module that touches this variable holds
        // `LAMBDA_MULT_ENV_LOCK`. NOTE(review): tests elsewhere in the crate
        // that read the process environment are not covered by this lock.
        unsafe {
            match value {
                Some(v) => std::env::set_var(LAMBDA_MULT_VAR, v),
                None => std::env::remove_var(LAMBDA_MULT_VAR),
            }
        }
        body()
    }

    fn intra(qp: u8) -> QuantParams {
        QuantParams {
            qp,
            slice: QuantSlice::Intra,
        }
    }

    fn inter(qp: u8) -> QuantParams {
        QuantParams {
            qp,
            slice: QuantSlice::Inter,
        }
    }

    #[test]
    fn quant_zero_in_zero_out() {
        let zero = [[0i32; 4]; 4];
        for qp in [0, 12, 24, 36, 51] {
            let q = forward_quantize_4x4(&zero, intra(qp));
            for row in &q {
                assert_eq!(row, &[0, 0, 0, 0], "qp={qp}");
            }
        }
    }

    #[test]
    fn quant_small_coeff_falls_in_deadzone_at_high_qp() {
        let mut x = [[0i32; 4]; 4];
        for i in 0..4 {
            for j in 0..4 {
                if (i, j) != (0, 0) {
                    x[i][j] = 80;
                }
            }
        }
        let q = forward_quantize_4x4(&x, intra(36));
        for i in 0..4 {
            for j in 0..4 {
                if (i, j) != (0, 0) {
                    assert_eq!(q[i][j], 0, "qp=36 deadzone at ({i},{j})");
                }
            }
        }
    }

    #[test]
    fn quant_higher_qp_more_zeros() {
        let x = [
            [400, 80, 60, 40],
            [80, 60, 40, 30],
            [60, 40, 30, 20],
            [40, 30, 20, 10],
        ];
        let zeros = |q: &[[i32; 4]; 4]| -> usize {
            q.iter().flatten().filter(|&&v| v == 0).count()
        };
        let zero_lo = zeros(&forward_quantize_4x4(&x, intra(12)));
        let zero_hi = zeros(&forward_quantize_4x4(&x, intra(36)));
        assert!(
            zero_hi >= zero_lo,
            "qp=36 should produce ≥ zeros than qp=12 ({zero_hi} vs {zero_lo})"
        );
        assert!(
            zero_hi > zero_lo,
            "qp=36 didn't produce strictly more zeros than qp=12"
        );
    }

    /// The inter rounding offset (`2^qbits / 6`) is smaller than the intra
    /// one (`2^qbits / 3`), so inter quantization zeroes at least as many
    /// coefficients.
    #[test]
    fn quant_inter_deadzone_wider_than_intra() {
        let x = [
            [50, 30, 20, 10],
            [30, 25, 15, 8],
            [20, 15, 10, 5],
            [10, 8, 5, 3],
        ];
        let zeros = |q: &[[i32; 4]; 4]| -> usize {
            q.iter().flatten().filter(|&&v| v == 0).count()
        };
        let qp = 24;
        let z_intra = zeros(&forward_quantize_4x4(&x, intra(qp)));
        let z_inter = zeros(&forward_quantize_4x4(&x, inter(qp)));
        assert!(
            z_inter >= z_intra,
            "inter dead-zone wider than intra ({z_inter} vs {z_intra})"
        );

        // A coefficient that lands between the two dead-zones. At qp=24:
        // 30 * 13107 = 393210. Intra: (393210 + 174762) >> 19 = 1.
        // Inter: (393210 + 87381) >> 19 = 0.
        let mut probe = [[0i32; 4]; 4];
        probe[0][0] = 30;
        assert_eq!(forward_quantize_4x4(&probe, intra(qp))[0][0], 1);
        assert_eq!(forward_quantize_4x4(&probe, inter(qp))[0][0], 0);
    }

    #[test]
    fn quant_sign_preserved() {
        let x = [
            [800, -800, 600, -600],
            [-400, 400, -300, 300],
            [200, -200, 100, -100],
            [-90, 90, -80, 80],
        ];
        let q = forward_quantize_4x4(&x, intra(12));
        for i in 0..4 {
            for j in 0..4 {
                if q[i][j] != 0 {
                    assert_eq!(
                        q[i][j].signum(),
                        x[i][j].signum(),
                        "sign flip at ({i},{j}): {} → {}",
                        x[i][j],
                        q[i][j],
                    );
                }
            }
        }
    }

    /// Smoke test: the forward DCT → quant → dequant → inverse pipeline runs
    /// without panicking across several qp values.
    #[test]
    fn quant_dequant_round_trip_pipeline_runs() {
        use crate::codec::h264::encoder::transform::forward_dct_4x4;
        let x = [
            [120, -80, 40, -20],
            [-100, 60, -30, 15],
            [80, -50, 25, -12],
            [-60, 40, -20, 10],
        ];
        for qp in [12u8, 22, 36] {
            let y = forward_dct_4x4(&x);
            let q = forward_quantize_4x4(&y, intra(qp));
            let dq = dequant_4x4(&q, qp as i32, false);
            let _recovered = inverse_4x4_integer(&dq);
            let total_mag: i32 = q.iter().flatten().map(|v| v.abs()).sum();
            assert!(total_mag >= 0, "qp={qp} levels weren't computed");
        }
    }

    #[test]
    fn quant_dc_luma_zero_in_zero_out() {
        let zero = [[0i32; 4]; 4];
        for qp in [0, 12, 24, 36, 51] {
            let q = forward_quantize_dc_luma(&zero, qp, QuantSlice::Intra);
            for row in &q {
                assert_eq!(row, &[0, 0, 0, 0], "qp={qp}");
            }
        }
    }

    #[test]
    fn quant_dc_luma_sign_preserved() {
        let dc = [
            [400, -400, 200, -200],
            [-300, 300, -150, 150],
            [200, -200, 100, -100],
            [-50, 50, -30, 30],
        ];
        let q = forward_quantize_dc_luma(&dc, 12, QuantSlice::Intra);
        for i in 0..4 {
            for j in 0..4 {
                if q[i][j] != 0 {
                    assert_eq!(
                        q[i][j].signum(),
                        dc[i][j].signum(),
                        "DC sign flip at ({i},{j})",
                    );
                }
            }
        }
    }

    #[test]
    fn quant_dc_chroma_zero_in_zero_out() {
        let zero = [[0i32; 2]; 2];
        for qp_c in [0, 12, 24, 36, 39] {
            let q = forward_quantize_dc_chroma(&zero, qp_c, QuantSlice::Intra);
            assert_eq!(q, [[0, 0], [0, 0]], "qp_c={qp_c}");
        }
    }

    #[test]
    fn quant_dc_chroma_sign_preserved() {
        let dc = [[200, -150], [80, -60]];
        let q = forward_quantize_dc_chroma(&dc, 12, QuantSlice::Intra);
        for i in 0..2 {
            for j in 0..2 {
                if q[i][j] != 0 {
                    assert_eq!(
                        q[i][j].signum(),
                        dc[i][j].signum(),
                        "chroma DC sign flip at ({i},{j})",
                    );
                }
            }
        }
    }

    #[test]
    fn trellis_with_zero_lambda_matches_scalar() {
        // `enable = false` returns before the environment is read.
        let x = [[100, -50, 25, 0]; 4];
        let scalar = forward_quantize_4x4(&x, intra(22));
        let trellis = trellis_quantize_4x4(&x, intra(22), false).unwrap();
        assert_eq!(trellis, scalar, "lambda=0 must equal scalar");
    }

    #[test]
    fn trellis_zeroes_low_magnitude_levels_at_high_lambda() {
        let mut x = [[0i32; 4]; 4];
        x[3][3] = 200;
        let scalar = forward_quantize_4x4(&x, intra(22));
        let trellis =
            with_lambda_mult(None, || trellis_quantize_4x4(&x, intra(22), true).unwrap());
        let scalar_nonzero = scalar.iter().flatten().filter(|&&v| v != 0).count();
        let trellis_nonzero = trellis.iter().flatten().filter(|&&v| v != 0).count();
        assert!(
            trellis_nonzero <= scalar_nonzero,
            "trellis must not add nonzero levels: {trellis_nonzero} > {scalar_nonzero}"
        );
    }

    #[test]
    fn trellis_preserves_high_magnitude_levels() {
        let mut x = [[0i32; 4]; 4];
        x[0][0] = 4000;
        let trellis =
            with_lambda_mult(None, || trellis_quantize_4x4(&x, intra(22), true).unwrap());
        assert!(
            trellis[0][0] != 0,
            "large coefficient should survive trellis"
        );
    }

    /// Exercises the lambda lookup for every legal qp through the trellis.
    /// Checks that the trellis only ever zeroes scalar levels and never
    /// changes them to another value.
    #[test]
    fn trellis_lambda_shim_handles_full_qp_range() {
        let x = [
            [800, 200, 100, 50],
            [200, 150, 80, 30],
            [100, 80, 50, 20],
            [50, 30, 20, 10],
        ];
        with_lambda_mult(None, || {
            for qp in 0..=51u8 {
                let scalar = forward_quantize_4x4(&x, inter(qp));
                let trellis = trellis_quantize_4x4(&x, inter(qp), true).unwrap();
                for (&s, &t) in scalar.iter().flatten().zip(trellis.iter().flatten()) {
                    assert!(
                        t == s || t == 0,
                        "qp={qp}: trellis produced {t} from scalar level {s}"
                    );
                }
            }
        });
    }

    #[test]
    fn trellis_uses_spec_lambda2_table() {
        let mut x = [[0i32; 4]; 4];
        x[0][0] = 10_000;
        x[3][3] = 131;
        let trellis = with_lambda_mult(Some("65536"), || {
            trellis_quantize_4x4(&x, inter(28), true).unwrap()
        });
        assert!(trellis[0][0] != 0, "huge DC must survive spec trellis");
        assert_eq!(
            trellis[3][3], 0,
            "tiny high-freq AC should be dropped by aggressive trellis"
        );
    }

    #[test]
    fn trellis_with_default_mult_is_near_neutral() {
        let x = [
            [800, 200, 100, 50],
            [200, 150, 80, 30],
            [100, 80, 50, 20],
            [50, 30, 20, 10],
        ];
        let scalar = forward_quantize_4x4(&x, inter(22));
        let trellis =
            with_lambda_mult(None, || trellis_quantize_4x4(&x, inter(22), true).unwrap());
        let scalar_nonzero = scalar.iter().flatten().filter(|&&v| v != 0).count();
        let trellis_nonzero = trellis.iter().flatten().filter(|&&v| v != 0).count();
        assert!(
            trellis_nonzero >= scalar_nonzero.saturating_sub(2),
            "default MULT should drop at most ~1-2 tiny AC levels, \
             got {} → {}",
            scalar_nonzero, trellis_nonzero
        );
    }
}