use crate::jpeg::quantize::ZIGZAG;
/// Beam width: how many of the cheapest trellis states survive each
/// coefficient step in `trellis_quantize`.
const MAX_STATES: usize = 8;
/// Upper bound on the number of levels `generate_candidates` returns; used as
/// a capacity hint when collecting successor states.
const MAX_CANDIDATES: usize = 5;
/// Lagrange multiplier used by `trellis_quantize` when the caller passes `None`.
const DEFAULT_LAMBDA: f32 = 1.0;
/// One node of the trellis search: a partial quantization path ending at a
/// particular AC coefficient.
#[derive(Clone, Copy)]
struct TrellisState {
    /// Index of this state's predecessor in the previous layer of states.
    parent: u16,
    /// Quantized level chosen for the coefficient at this step.
    value: i16,
    /// Accumulated cost of the path up to and including this step.
    cost: f32,
    /// Length of the pending zero run at the end of this path, as maintained
    /// by `trellis_quantize`.
    zero_run: u8,
}

impl Default for TrellisState {
    /// A sentinel state with infinite cost and every other field zeroed.
    fn default() -> Self {
        TrellisState {
            value: 0,
            parent: 0,
            zero_run: 0,
            cost: f32::INFINITY,
        }
    }
}
/// Quantizes one 8x8 block of DCT coefficients, choosing the AC levels by
/// rate-distortion optimized (trellis) search.
///
/// The DC coefficient (natural index 0) is not part of the search; it is
/// plainly rounded. The 63 AC coefficients are visited in zigzag order. For
/// each one, `generate_candidates` proposes a few levels. A beam of at most
/// `MAX_STATES` partial paths is kept, minimizing the Lagrangian cost
/// `J = D + lambda * R`:
/// * `D` is the squared reconstruction error in DCT units.
/// * `R` is the estimated number of entropy-coded bits.
///
/// The rate model mirrors baseline JPEG Huffman coding of AC coefficients:
/// * Zeros cost nothing by themselves.
/// * A nonzero value preceded by `run` zeros pays one ZRL per full 16 zeros,
///   plus the `(run % 16, size)` symbol and its magnitude bits.
/// * A path whose final coefficients are zero pays a single EOB.
///
/// Trailing zeros are therefore never charged ZRLs.
///
/// # Arguments
/// * `dct` - DCT coefficients in natural (row-major) order.
/// * `quant_table` - quantization step per coefficient, natural order.
/// * `lambda` - Lagrange multiplier; larger values weight rate more heavily
///   and so favour more zero levels. Defaults to `DEFAULT_LAMBDA`.
///
/// # Returns
/// Quantized levels in natural order.
pub fn trellis_quantize(
    dct: &[f32; 64],
    quant_table: &[f32; 64],
    lambda: Option<f32>,
) -> [i16; 64] {
    let lambda = lambda.unwrap_or(DEFAULT_LAMBDA);
    let mut result = [0i16; 64];
    result[0] = (dct[0] / quant_table[0]).round() as i16;

    // history[k] holds the surviving states after zigzag position k;
    // history[0] is the single empty-path root. Each state's `parent` indexes
    // into the previous layer, so the best path is recovered by walking back.
    let mut history: Vec<Vec<TrellisState>> = Vec::with_capacity(64);
    history.push(vec![TrellisState {
        cost: 0.0,
        zero_run: 0,
        parent: 0,
        value: 0,
    }]);

    for zz_pos in 1..64 {
        let natural_pos = ZIGZAG[zz_pos];
        let coef = dct[natural_pos];
        let q = quant_table[natural_pos];
        let candidates = generate_candidates(coef / q);
        let prev = history.last().expect("history always holds the root layer");
        let mut next: Vec<TrellisState> = Vec::with_capacity(MAX_STATES * MAX_CANDIDATES);

        for (parent_idx, parent) in prev.iter().enumerate() {
            for &candidate in &candidates {
                let reconstructed = candidate as f32 * q;
                let distortion = (coef - reconstructed).powi(2);
                let (rate, zero_run) = if candidate == 0 {
                    // Zeros are paid for later: by ZRLs + the run/size symbol if
                    // a nonzero value follows, or by one EOB if none does.
                    // At most 63 AC zeros, so this cannot overflow u8.
                    (0.0, parent.zero_run + 1)
                } else {
                    let run = parent.zero_run;
                    let zrl_bits = f32::from(run / 16) * estimate_zrl_rate();
                    (zrl_bits + estimate_ac_rate(candidate, run % 16), 0)
                };
                let state = TrellisState {
                    cost: parent.cost + distortion + lambda * rate,
                    zero_run,
                    // prev.len() <= MAX_STATES, so the index fits in u16.
                    parent: parent_idx as u16,
                    value: candidate,
                };
                // Future costs depend only on zero_run, so among states sharing
                // a run length only the cheapest can lie on an optimal path.
                let existing = next.iter_mut().find(|s| s.zero_run == zero_run);
                match existing {
                    Some(s) if state.cost < s.cost => *s = state,
                    Some(_) => {}
                    None => next.push(state),
                }
            }
        }

        next.sort_by(|a, b| {
            a.cost
                .partial_cmp(&b.cost)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        next.truncate(MAX_STATES);
        history.push(next);
    }

    // Terminate every surviving path: a pending run of zeros is coded as EOB;
    // a path ending on a nonzero coefficient needs no EOB.
    let eob_cost = lambda * estimate_eob_rate();
    let final_layer = history.last().expect("history always holds the root layer");
    let best = final_layer
        .iter()
        .enumerate()
        .map(|(idx, s)| {
            let total = if s.zero_run > 0 { s.cost + eob_cost } else { s.cost };
            (idx, total)
        })
        .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));

    if let Some((mut idx, _)) = best {
        for zz_pos in (1..64).rev() {
            let state = history[zz_pos][idx];
            result[ZIGZAG[zz_pos]] = state.value;
            idx = usize::from(state.parent);
        }
    }
    result
}
/// Returns the candidate quantization levels the trellis considers for one
/// AC coefficient whose unrounded quantized value is `float_quant`.
///
/// The list always starts with `0`. It then contains, without duplicates,
/// the floor, nearest integer and ceiling of `float_quant`. When
/// `|float_quant| > 1.5`, one more level is added: one step closer to zero
/// than the floor/ceiling nearest to zero.
///
/// A smaller magnitude can drop into a lower size category and save bits. A
/// level farther from zero than the ceiling only adds distortion and, under
/// the size-category rate model used in this module, never reduces rate.
/// The result never holds more than `MAX_CANDIDATES` entries.
fn generate_candidates(float_quant: f32) -> Vec<i16> {
    let rounded = float_quant.round() as i16;
    let floor_val = float_quant.floor() as i16;
    let ceil_val = float_quant.ceil() as i16;
    let mut candidates = Vec::with_capacity(MAX_CANDIDATES);
    candidates.push(0);
    for value in [floor_val, rounded, ceil_val] {
        if value != 0 && !candidates.contains(&value) {
            candidates.push(value);
        }
    }
    if float_quant.abs() > 1.5 {
        // Step one level toward zero. This never overflows i16, even when the
        // float-to-int casts above saturated for huge or infinite inputs.
        let reduced = if float_quant >= 0.0 {
            floor_val - 1
        } else {
            ceil_val + 1
        };
        if !candidates.contains(&reduced) {
            candidates.push(reduced);
        }
    }
    candidates
}
/// Estimated bit cost of coding the nonzero AC `value` after `zero_run`
/// zeros.
///
/// The cost is the approximate Huffman length of the `(run << 4) | size`
/// symbol, plus `size` raw magnitude bits.
fn estimate_ac_rate(value: i16, zero_run: u8) -> f32 {
    let size = category(value);
    let symbol = (usize::from(zero_run) << 4) | usize::from(size);
    estimate_ac_huffman_length(symbol) + f32::from(size)
}
/// Approximate Huffman code length, in bits, of the AC symbol `rs`
/// (`run << 4 | size`).
///
/// A handful of frequent symbols have fixed lengths. Every other symbol
/// falls back to the linear estimate `3 + 0.5 * run + 0.3 * size`.
fn estimate_ac_huffman_length(rs: usize) -> f32 {
    const KNOWN_LENGTHS: [(usize, f32); 9] = [
        (0x00, 4.0),
        (0x01, 2.0),
        (0x02, 2.5),
        (0x03, 3.0),
        (0x04, 4.0),
        (0x11, 3.0),
        (0x12, 4.0),
        (0x21, 4.0),
        (0xF0, 10.0),
    ];
    if let Some(&(_, bits)) = KNOWN_LENGTHS.iter().find(|&&(symbol, _)| symbol == rs) {
        return bits;
    }
    let run = (rs >> 4) as f32;
    let size = (rs & 0x0F) as f32;
    3.0 + run * 0.5 + size * 0.3
}
/// Approximate bit cost of one ZRL symbol, which stands for a run of
/// sixteen zero coefficients.
fn estimate_zrl_rate() -> f32 {
    const ZRL_BITS: f32 = 10.0;
    ZRL_BITS
}
/// Approximate bit cost of the end-of-block (EOB) symbol.
fn estimate_eob_rate() -> f32 {
    const EOB_BITS: f32 = 4.0;
    EOB_BITS
}
/// Magnitude category of `value`: the number of bits needed to represent
/// `|value|`, or `0` when `value` is zero.
fn category(value: i16) -> u8 {
    match value.unsigned_abs() {
        0 => 0,
        magnitude => (u16::BITS - magnitude.leading_zeros()) as u8,
    }
}
/// Trellis-quantizes a block, deriving the Lagrange multiplier from a
/// JPEG-style quality setting.
///
/// `quality` is clamped to at most 100 and mapped piecewise-linearly to
/// `lambda`: 0.5 at 100, 1.0 at 80, about 2.0 at 50 and 4.0 at 0. Lower
/// quality therefore yields a larger `lambda`, which is passed straight to
/// `trellis_quantize`.
pub fn trellis_quantize_adaptive(
    dct: &[f32; 64],
    quant_table: &[f32; 64],
    quality: u8,
) -> [i16; 64] {
    // Clamp first: `100 - quality` on a u8 underflows for quality > 100
    // (panic in debug builds, wrapped nonsense lambda in release builds).
    let quality = quality.min(100);
    let lambda = if quality >= 80 {
        0.5 + f32::from(100 - quality) * 0.025
    } else if quality >= 50 {
        1.0 + f32::from(80 - quality) * 0.033
    } else {
        2.0 + f32::from(50 - quality) * 0.04
    };
    trellis_quantize(dct, quant_table, Some(lambda))
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_trellis_quantize_zeros() {
        let dct = [0.0f32; 64];
        let quant = [16.0f32; 64];
        let result = trellis_quantize(&dct, &quant, None);
        assert_eq!(result, [0i16; 64]);
    }

    #[test]
    fn test_trellis_quantize_dc() {
        let mut dct = [0.0f32; 64];
        dct[0] = 800.0;
        let quant = [16.0f32; 64];
        let result = trellis_quantize(&dct, &quant, None);
        // 800 / 16 == 50 exactly.
        assert_eq!(result[0], 50);
    }

    #[test]
    fn test_generate_candidates() {
        let candidates = generate_candidates(5.3);
        assert!(candidates.contains(&0));
        assert!(candidates.contains(&5));
        assert!(candidates.contains(&6));

        let candidates = generate_candidates(-5.3);
        assert!(candidates.contains(&0));
        assert!(candidates.contains(&-5));
        assert!(candidates.contains(&-6));
    }

    #[test]
    fn test_generate_candidates_zero() {
        let candidates = generate_candidates(0.1);
        assert!(candidates.contains(&0));
        // ceil(0.1) == 1 is always offered alongside zero.
        assert!(candidates.contains(&1));
    }

    #[test]
    fn test_category() {
        assert_eq!(category(0), 0);
        assert_eq!(category(1), 1);
        assert_eq!(category(-1), 1);
        assert_eq!(category(2), 2);
        assert_eq!(category(3), 2);
        assert_eq!(category(127), 7);
        assert_eq!(category(-128), 8);
    }

    #[test]
    fn test_trellis_sparsity() {
        let mut dct = [0.0f32; 64];
        // Every AC coefficient sits exactly halfway between levels 0 and 1,
        // so both levels have equal distortion and rate alone decides.
        dct[1..].fill(8.0);
        let quant = [16.0f32; 64];
        let result = trellis_quantize(&dct, &quant, Some(2.0));
        let zero_count = result.iter().skip(1).filter(|&&x| x == 0).count();
        assert!(
            zero_count > 30,
            "Expected many zeros with high lambda, got {zero_count}"
        );
    }

    #[test]
    fn test_adaptive_lambda() {
        let mut dct = [0.0f32; 64];
        dct[0] = 800.0;
        dct[1] = 50.0;
        let quant = [16.0f32; 64];
        let high_q = trellis_quantize_adaptive(&dct, &quant, 95);
        let low_q = trellis_quantize_adaptive(&dct, &quant, 30);
        assert_eq!(high_q[0], low_q[0]);
        assert!(high_q[1].abs() <= 5);
        assert!(low_q[1].abs() <= 5);
    }

    #[test]
    fn test_category_edge_cases() {
        assert_eq!(category(i16::MAX), 15);
        assert_eq!(category(i16::MIN + 1), 15);
        assert_eq!(category(16383), 14);
        assert_eq!(category(-16384), 15);
    }

    #[test]
    fn test_generate_candidates_negative() {
        let candidates = generate_candidates(-3.7);
        assert!(candidates.contains(&0));
        assert!(candidates.contains(&-3) || candidates.contains(&-4));
    }

    #[test]
    fn test_generate_candidates_exact_integer() {
        let candidates = generate_candidates(5.0);
        assert!(candidates.contains(&0));
        assert!(candidates.contains(&5));
    }

    #[test]
    fn test_generate_candidates_small() {
        // f32::round rounds half away from zero and ceil(0.5) == 1, so level 1
        // must be offered.
        let candidates = generate_candidates(0.5);
        assert!(candidates.contains(&0));
        assert!(candidates.contains(&1));
    }

    #[test]
    fn test_trellis_quantize_single_ac() {
        let mut dct = [0.0f32; 64];
        dct[0] = 400.0;
        dct[1] = 50.0;
        let quant = [16.0f32; 64];
        let result = trellis_quantize(&dct, &quant, None);
        assert_eq!(result[0], 25);
        // 50 / 16 = 3.125: level 3 has far lower distortion than any other.
        assert_eq!(result[1], 3);
    }

    #[test]
    fn test_trellis_quantize_preserves_dc() {
        let mut dct = [0.0f32; 64];
        dct[0] = 160.0;
        let quant = [16.0f32; 64];
        let result = trellis_quantize(&dct, &quant, Some(10.0));
        assert_eq!(result[0], 10);
    }

    #[test]
    fn test_trellis_quantize_high_frequency() {
        let mut dct = [0.0f32; 64];
        dct[0] = 200.0;
        dct[63] = 32.0;
        dct[62] = 48.0;
        let quant = [16.0f32; 64];
        let result = trellis_quantize(&dct, &quant, None);
        // 200 / 16 = 12.5 rounds half away from zero.
        assert_eq!(result[0], 13);
        // Exactly representable levels are kept.
        assert_eq!(result[62], 3);
        assert_eq!(result[63], 2);
    }

    #[test]
    fn test_trellis_quantize_near_threshold() {
        let mut dct = [0.0f32; 64];
        dct[0] = 160.0;
        dct[1] = 8.1;
        dct[2] = 7.9;
        let quant = [16.0f32; 64];
        let result = trellis_quantize(&dct, &quant, None);
        assert!(result[1].abs() <= 1);
        assert!(result[2].abs() <= 1);
    }

    #[test]
    fn test_estimate_ac_huffman_length_common_symbols() {
        assert!(estimate_ac_huffman_length(0x00) <= 4.0);
        assert!(estimate_ac_huffman_length(0x01) <= 3.0);
        assert!(estimate_ac_huffman_length(0xF0) >= 8.0);
    }

    #[test]
    fn test_trellis_quantize_negative_coefficients() {
        let mut dct = [0.0f32; 64];
        dct[0] = -100.0;
        dct[1] = -50.0;
        dct[2] = 30.0;
        let quant = [10.0f32; 64];
        let result = trellis_quantize(&dct, &quant, None);
        assert_eq!(result[0], -10);
        assert_eq!(result[1], -5);
        assert_eq!(result[2], 3);
    }

    #[test]
    fn test_adaptive_quality_boundaries() {
        let mut dct = [0.0f32; 64];
        dct[0] = 500.0;
        let quant = [16.0f32; 64];
        // DC is rounded independently of lambda: 500 / 16 = 31.25.
        for quality in [1, 50, 80, 100] {
            let result = trellis_quantize_adaptive(&dct, &quant, quality);
            assert_eq!(result[0], 31, "quality {quality}");
        }
    }

    #[test]
    fn test_trellis_state_default() {
        let state = TrellisState::default();
        assert_eq!(state.cost, f32::INFINITY);
        assert_eq!(state.zero_run, 0);
        assert_eq!(state.parent, 0);
        assert_eq!(state.value, 0);
    }

    #[test]
    fn test_trellis_with_custom_lambda() {
        let mut dct = [0.0f32; 64];
        dct[0] = 800.0;
        dct[1..].fill(10.0);
        let quant = [16.0f32; 64];
        let result_low = trellis_quantize(&dct, &quant, Some(0.1));
        let result_high = trellis_quantize(&dct, &quant, Some(10.0));
        let nonzero_low = result_low.iter().skip(1).filter(|&&x| x != 0).count();
        let nonzero_high = result_high.iter().skip(1).filter(|&&x| x != 0).count();
        assert!(
            nonzero_high <= nonzero_low,
            "High lambda ({nonzero_high}) should be sparser than low ({nonzero_low})"
        );
    }

    #[test]
    fn test_trellis_zigzag_ordering() {
        let mut dct = [0.0f32; 64];
        dct[0] = 200.0;
        dct[1] = 50.0;
        dct[8] = 40.0;
        let quant = [10.0f32; 64];
        let result = trellis_quantize(&dct, &quant, Some(0.5));
        assert_eq!(result[0], 20);
        // Levels must land back at their natural (row-major) positions.
        assert_eq!(result[1], 5);
        assert_eq!(result[8], 4);
        for (i, &v) in result.iter().enumerate() {
            if ![0, 1, 8].contains(&i) {
                assert_eq!(v, 0, "unexpected nonzero at natural index {i}");
            }
        }
    }
}