use super::compat::TrellisConfig;
use crate::foundation::consts::{DCT_BLOCK_SIZE, JPEG_NATURAL_ORDER};
use super::rate::RateTable;
/// Returns how many bits are needed to write the magnitude of `val`.
///
/// Zero needs no bits. Any other value needs the position of the highest
/// set bit of `|val|`, plus one. For example, 1 -> 1, 2..=3 -> 2, 4..=7 -> 3,
/// and `i16::MIN` (magnitude 32768) -> 16.
#[inline]
pub(crate) fn jpeg_nbits(val: i16) -> u8 {
    let magnitude: u16 = val.unsigned_abs();
    // `0u16.leading_zeros()` is 16, so zero falls out as 0 bits without a branch.
    let width = u16::BITS - magnitude.leading_zeros();
    width as u8
}
/// Rate-distortion optimised ("trellis") quantization of one 8x8 block.
///
/// The DC coefficient (`src[0]`) is quantized by round-to-nearest and clamped
/// to `MAX_COEF_VAL`. The AC positions are then visited in scan order: `i`
/// runs from 1 to 63 and `z = JPEG_NATURAL_ORDER[i]` is the matching index
/// into `src`, `qtable` and `quantized`.
///
/// For each AC coefficient whose rounded value is non-zero, a dynamic program
/// tries several candidate magnitudes: `1, 3, 7, ...` (`(2 << k) - 1`) and
/// finally the rounded value itself, limited by the speed mode. It pairs each
/// candidate with every admissible predecessor `j`: the DC anchor, or an
/// earlier position that was kept non-zero, within the look-back window. The
/// cost of a choice is:
///
/// `AC code length + magnitude bits + ZRL code lengths + weighted squared
/// error of this coefficient + error of the zeroed run + cost at j`.
///
/// After the forward pass, the cheapest stopping point is chosen; an EOB code
/// length is charged unless the block ends at position 63. The winning path is
/// then traced back through `run_start`, and every AC coefficient not on it is
/// zeroed.
///
/// # Parameters
/// - `src`: input coefficients. The quantizer step for position `z` is
///   `8 * qtable[z]`, so values are expected at 8x the scale of the quantized
///   result (the tests in this file scale inputs by 8).
/// - `quantized`: output. Every entry is overwritten; previous contents are
///   never read.
/// - `qtable`: quantizer steps, indexed like `src`.
/// - `ac_table`: supplies AC code lengths via `get_code(symbol).1`. A length of
///   0 is treated as "symbol not encodable" and that option is skipped.
/// - `config`: `lambda_log_scale1`/`lambda_log_scale2` set the rate/distortion
///   trade-off. `speed_mode.get_limits(n)` bounds the look-back distance and
///   the number of candidates, given `n` = the number of AC coefficients that
///   round to non-zero.
///
/// # Panics
/// Panics (integer division by zero) if any `qtable` entry is 0.
// Index loops are kept deliberately: most arrays here are addressed both by
// scan position `i` and by natural position `z`.
#[allow(clippy::needless_range_loop)]
pub fn trellis_quantize_block(
    src: &[i32; DCT_BLOCK_SIZE],
    quantized: &mut [i16; DCT_BLOCK_SIZE],
    qtable: &[u16; DCT_BLOCK_SIZE],
    ac_table: &RateTable,
    config: &TrellisConfig,
) {
    /// Largest magnitude a quantized coefficient may take (10 bits).
    const MAX_COEF_VAL: i32 = (1 << 10) - 1;

    // Per-position distortion weight: errors are measured relative to the
    // square of that position's quantizer step.
    let mut lambda_tbl = [0.0f32; DCT_BLOCK_SIZE];
    for i in 0..DCT_BLOCK_SIZE {
        let q = qtable[i] as f32;
        lambda_tbl[i] = 1.0 / (q * q);
    }

    // Mean squared AC energy of the block; used to adapt lambda below.
    let mut norm: f32 = 0.0;
    for i in 1..DCT_BLOCK_SIZE {
        let c = src[i] as f32;
        norm += c * c;
    }
    norm /= 63.0;
    let lambda = if config.lambda_log_scale2 > 0.0 {
        let scale1 = 2.0_f32.powf(config.lambda_log_scale1);
        let scale2 = 2.0_f32.powf(config.lambda_log_scale2);
        scale1 / (scale2 + norm)
    } else {
        2.0_f32.powf(config.lambda_log_scale1 - 12.0)
    };

    // accumulated_zero_dist[i]: distortion of zeroing every AC position 1..=i.
    // accumulated_cost[i]: best cost of a path whose last non-zero is at i
    //                      (f32::MAX when no such path exists).
    // run_start[i]: predecessor position chosen for i on that best path.
    let mut accumulated_zero_dist = [0.0f32; DCT_BLOCK_SIZE];
    let mut accumulated_cost = [0.0f32; DCT_BLOCK_SIZE];
    let mut run_start = [0usize; DCT_BLOCK_SIZE];

    // DC: plain round-to-nearest, clamped, sign restored.
    {
        let x = src[0].abs();
        let sign = if src[0] < 0 { -1i16 } else { 1i16 };
        let q = 8 * qtable[0] as i32;
        let qval = ((x + q / 2) / q).min(MAX_COEF_VAL);
        quantized[0] = (qval as i16) * sign;
    }
    // Position 0 is the zero-cost anchor every path starts from.
    accumulated_zero_dist[0] = 0.0;
    accumulated_cost[0] = 0.0;

    // Speed-mode limits depend on how many AC coefficients round to non-zero.
    let (max_lookback, max_candidates) = {
        let mut nonzero_count = 0i32;
        for i in 1..DCT_BLOCK_SIZE {
            let z = JPEG_NATURAL_ORDER[i] as usize;
            let x = src[z].abs();
            let q = 8 * qtable[z] as i32;
            if (x + q / 2) / q > 0 {
                nonzero_count += 1;
            }
        }
        config.speed_mode.get_limits(nonzero_count)
    };

    // Forward pass over AC positions in scan order.
    for i in 1..DCT_BLOCK_SIZE {
        let z = JPEG_NATURAL_ORDER[i] as usize;
        let x = src[z].abs();
        let sign = if src[z] < 0 { -1i16 } else { 1i16 };
        let q = 8 * qtable[z] as i32;

        // Distortion if this coefficient ends up coded as zero.
        let zero_dist = (x as f32).powi(2) * lambda * lambda_tbl[z];
        accumulated_zero_dist[i] = zero_dist + accumulated_zero_dist[i - 1];

        // Start from zero so the output never depends on the caller's buffer,
        // including when no candidate below turns out to be encodable. Later
        // iterations read this slot to decide whether `i` may be a predecessor.
        quantized[z] = 0;

        let qval = (x + q / 2) / q;
        if qval == 0 {
            // Rounds to zero: never coded, never a predecessor.
            accumulated_cost[i] = f32::MAX;
            run_start[i] = i - 1;
            continue;
        }
        let qval = qval.min(MAX_COEF_VAL);

        // Candidate k has magnitude (2 << k) - 1 (bit size k + 1), except the
        // last candidate, which is the rounded value itself.
        let num_candidates = (jpeg_nbits(qval as i16) as usize).min(max_candidates);
        let mut candidates = [(0i32, 0u8, 0.0f32); 16];
        for k in 0..num_candidates {
            let candidate_val = if k < num_candidates - 1 {
                (2 << k) - 1
            } else {
                qval
            };
            let delta = candidate_val * q - x;
            let dist = (delta as f32).powi(2) * lambda * lambda_tbl[z];
            candidates[k] = (candidate_val, (k + 1) as u8, dist);
        }

        accumulated_cost[i] = f32::MAX;
        let j_start = i.saturating_sub(max_lookback);
        for j in j_start..i {
            let zz = JPEG_NATURAL_ORDER[j] as usize;
            // Only the DC anchor or a kept (non-zero) coefficient can precede.
            if j != 0 && quantized[zz] == 0 {
                continue;
            }
            let zero_run = i - 1 - j;
            // Runs of 16+ zeros need one ZRL (0xF0) symbol per 16 zeros.
            let zrl_cost = if zero_run >= 16 {
                let (_, zrl_size) = ac_table.get_code(0xF0);
                if zrl_size == 0 {
                    continue;
                }
                (zero_run / 16) * zrl_size as usize
            } else {
                0
            };
            let run_mod_16 = zero_run & 15;
            for k in 0..num_candidates {
                let (candidate_val, candidate_bits, candidate_dist) = candidates[k];
                // AC symbol: high nibble = zero run, low nibble = bit size.
                let symbol = ((run_mod_16 as u8) << 4) | candidate_bits;
                let (_, code_size) = ac_table.get_code(symbol);
                if code_size == 0 {
                    continue;
                }
                let rate = code_size as usize + candidate_bits as usize + zrl_cost;
                let zero_run_dist = accumulated_zero_dist[i - 1] - accumulated_zero_dist[j];
                let prev_cost = if j == 0 { 0.0 } else { accumulated_cost[j] };
                let cost = rate as f32 + candidate_dist + zero_run_dist + prev_cost;
                if cost < accumulated_cost[i] {
                    quantized[z] = (candidate_val as i16) * sign;
                    accumulated_cost[i] = cost;
                    run_start[i] = j;
                }
            }
        }
    }

    // Pick the cheapest stopping point; the baseline is "all AC zero + EOB".
    let eob_cost = {
        let (_, eob_size) = ac_table.get_code(0x00);
        eob_size as f32
    };
    let mut best_cost = accumulated_zero_dist[DCT_BLOCK_SIZE - 1] + eob_cost;
    let mut last_coeff_idx = 0;
    for i in 1..DCT_BLOCK_SIZE {
        let z = JPEG_NATURAL_ORDER[i] as usize;
        if quantized[z] != 0 {
            let tail_zero_dist =
                accumulated_zero_dist[DCT_BLOCK_SIZE - 1] - accumulated_zero_dist[i];
            let mut cost = accumulated_cost[i] + tail_zero_dist;
            // No EOB is needed when the last coded coefficient is position 63.
            if i < DCT_BLOCK_SIZE - 1 {
                cost += eob_cost;
            }
            if cost < best_cost {
                best_cost = cost;
                last_coeff_idx = i;
            }
        }
    }

    // Trace back: zero everything after the chosen end, then every gap between
    // consecutive path positions, following `run_start`. DC is never touched.
    let mut i = DCT_BLOCK_SIZE - 1;
    while i >= 1 {
        while i > last_coeff_idx {
            let z = JPEG_NATURAL_ORDER[i] as usize;
            quantized[z] = 0;
            i -= 1;
        }
        if i >= 1 {
            last_coeff_idx = run_start[i];
            i -= 1;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::encode::trellis::TrellisSpeedMode;

    /// Rate table built from the standard luminance AC Huffman table.
    fn create_ac_table() -> RateTable {
        RateTable::standard_luma_ac()
    }

    /// Quantization table used by every test (64 entries, all non-zero).
    fn create_qtable() -> [u16; DCT_BLOCK_SIZE] {
        #[rustfmt::skip]
        let table: [u16; 64] = [
            16, 11, 10, 16, 24, 40, 51, 61,
            12, 12, 14, 19, 26, 58, 60, 55,
            14, 13, 16, 24, 40, 57, 69, 56,
            14, 17, 22, 29, 51, 87, 80, 62,
            18, 22, 37, 56, 68, 109, 103, 77,
            24, 35, 55, 64, 81, 104, 113, 92,
            49, 64, 78, 87, 103, 121, 120, 101,
            72, 92, 95, 98, 112, 100, 103, 99,
        ];
        table
    }

    /// An all-zero input must quantize to an all-zero block.
    #[test]
    fn test_trellis_quantize_zero_block() {
        let ac_table = create_ac_table();
        let qtable = create_qtable();
        let config = TrellisConfig::default();
        let src = [0i32; DCT_BLOCK_SIZE];
        let mut quantized = [0i16; DCT_BLOCK_SIZE];
        trellis_quantize_block(&src, &mut quantized, &qtable, &ac_table, &config);
        for &q in quantized.iter() {
            assert_eq!(q, 0);
        }
    }

    /// A DC-only input keeps a positive DC and no AC coefficients.
    #[test]
    fn test_trellis_quantize_dc_only() {
        let ac_table = create_ac_table();
        let qtable = create_qtable();
        let config = TrellisConfig::default();
        let mut src = [0i32; DCT_BLOCK_SIZE];
        src[0] = 1000 * 8;
        let mut quantized = [0i16; DCT_BLOCK_SIZE];
        trellis_quantize_block(&src, &mut quantized, &qtable, &ac_table, &config);
        assert!(quantized[0] > 0);
        for (i, &q) in quantized.iter().enumerate().skip(1) {
            assert_eq!(q, 0, "AC coefficient {} should be zero", i);
        }
    }

    /// Only the DC term is asserted here: whether the AC term survives depends
    /// on the lambda produced by the configured defaults.
    #[test]
    fn test_trellis_preserves_large_coefficients() {
        let ac_table = create_ac_table();
        let qtable = create_qtable();
        let config = TrellisConfig::default();
        let mut src = [0i32; DCT_BLOCK_SIZE];
        src[0] = 500 * 8;
        src[1] = 200 * 8;
        let mut quantized = [0i16; DCT_BLOCK_SIZE];
        trellis_quantize_block(&src, &mut quantized, &qtable, &ac_table, &config);
        assert!(quantized[0] != 0);
    }

    /// A negative DC input must keep its sign after quantization.
    #[test]
    fn test_trellis_negative_coefficients() {
        let ac_table = create_ac_table();
        let qtable = create_qtable();
        let config = TrellisConfig::default();
        let mut src = [0i32; DCT_BLOCK_SIZE];
        src[0] = -1000 * 8;
        src[1] = -200 * 8;
        let mut quantized = [0i16; DCT_BLOCK_SIZE];
        trellis_quantize_block(&src, &mut quantized, &qtable, &ac_table, &config);
        assert!(quantized[0] < 0);
    }

    /// Every speed mode must run and keep a non-zero DC.
    #[test]
    fn test_speed_modes() {
        let ac_table = create_ac_table();
        let qtable = create_qtable();
        let mut src = [0i32; DCT_BLOCK_SIZE];
        for (i, s) in src.iter_mut().enumerate() {
            *s = ((i as i32 + 1) * 50) * 8;
        }
        for mode in [
            TrellisSpeedMode::Thorough,
            TrellisSpeedMode::Adaptive,
            TrellisSpeedMode::Level(5),
            TrellisSpeedMode::Level(10),
        ] {
            let config = TrellisConfig::default().speed_mode(mode);
            let mut quantized = [0i16; DCT_BLOCK_SIZE];
            trellis_quantize_block(&src, &mut quantized, &qtable, &ac_table, &config);
            assert!(
                quantized[0] != 0,
                "DC should be non-zero for mode {:?}",
                mode
            );
        }
    }

    /// Bit-width helper: boundary values for each bit length.
    #[test]
    fn test_jpeg_nbits() {
        assert_eq!(jpeg_nbits(0), 0);
        assert_eq!(jpeg_nbits(1), 1);
        assert_eq!(jpeg_nbits(-1), 1);
        assert_eq!(jpeg_nbits(2), 2);
        assert_eq!(jpeg_nbits(3), 2);
        assert_eq!(jpeg_nbits(4), 3);
        assert_eq!(jpeg_nbits(7), 3);
        assert_eq!(jpeg_nbits(8), 4);
        assert_eq!(jpeg_nbits(255), 8);
        assert_eq!(jpeg_nbits(1023), 10);
    }
}