use alloc::collections::BTreeMap;
use super::generate::TrialScan;
use crate::foundation::consts::DCT_BLOCK_SIZE;
use crate::huffman::optimize::FrequencyCounter;
use crate::huffman::optimize::cluster::cluster_histograms;
/// Cache key for one single-component scan histogram:
/// `(component, ss, se, ah, al)`. DC entries are always keyed with
/// `ss = se = ah = al = 0`.
type ScanKey = (u8, u8, u8, u8, u8);
/// A cached scan estimate: the Huffman symbol histogram plus the number of
/// extra bits (magnitude / correction / EOB-run bits) that are emitted
/// alongside the Huffman-coded symbols.
type CachedHistogram = (FrequencyCounter, usize);
/// Memoizes per-scan histogram estimates over one set of DCT coefficient
/// blocks, so that many candidate scans and scripts can be costed without
/// re-walking the coefficients for scan parameters already seen.
pub(crate) struct ScanHistogramCache<'a> {
    /// Estimates computed so far, keyed by scan parameters.
    cache: BTreeMap<ScanKey, CachedHistogram>,
    /// Blocks for component 0 (Y).
    y_blocks: &'a [[i16; DCT_BLOCK_SIZE]],
    /// Blocks for component 1 (Cb).
    cb_blocks: &'a [[i16; DCT_BLOCK_SIZE]],
    /// Blocks for component 2 (Cr).
    cr_blocks: &'a [[i16; DCT_BLOCK_SIZE]],
}
impl<'a> ScanHistogramCache<'a> {
pub fn warm(
trial_scans: &[TrialScan],
scripts: &[Vec<super::super::config::ProgressiveScan>],
y_blocks: &'a [[i16; DCT_BLOCK_SIZE]],
cb_blocks: &'a [[i16; DCT_BLOCK_SIZE]],
cr_blocks: &'a [[i16; DCT_BLOCK_SIZE]],
num_components: u8,
) -> Self {
let mut keys = alloc::collections::BTreeSet::<ScanKey>::new();
for scan in trial_scans {
if scan.is_dc() {
if scan.comps_in_scan > 1 {
for c in 0..num_components {
keys.insert((c, 0, 0, 0, 0));
}
} else {
keys.insert((scan.component, 0, 0, 0, 0));
}
} else {
keys.insert((scan.component, scan.ss, scan.se, scan.ah, scan.al));
}
}
for script in scripts {
for scan in script {
if scan.ss == 0 && scan.se == 0 {
for &comp in &scan.components {
keys.insert((comp, 0, 0, 0, 0));
}
} else {
keys.insert((scan.components[0], scan.ss, scan.se, scan.ah, scan.al));
}
}
}
let mut cache = BTreeMap::new();
for key in keys {
let (comp, ss, se, ah, al) = key;
let blocks = match comp {
0 => y_blocks,
1 => cb_blocks,
2 => cr_blocks,
_ => continue,
};
let entry = if ss == 0 && se == 0 {
estimate_dc_scan_detailed(blocks)
} else if ah == 0 {
estimate_ac_first_scan_detailed(blocks, ss, se, al)
} else {
estimate_ac_refinement_scan_detailed(blocks, ss, se, ah, al)
};
cache.insert(key, entry);
}
Self {
cache,
y_blocks,
cb_blocks,
cr_blocks,
}
}
fn get(&mut self, key: &ScanKey) -> &CachedHistogram {
if !self.cache.contains_key(key) {
let (comp, ss, se, ah, al) = *key;
let blocks = match comp {
0 => self.y_blocks,
1 => self.cb_blocks,
_ => self.cr_blocks,
};
let entry = if ss == 0 && se == 0 {
estimate_dc_scan_detailed(blocks)
} else if ah == 0 {
estimate_ac_first_scan_detailed(blocks, ss, se, al)
} else {
estimate_ac_refinement_scan_detailed(blocks, ss, se, ah, al)
};
self.cache.insert(*key, entry);
}
self.cache.get(key).unwrap()
}
pub fn estimate_all_scan_sizes_cached(&mut self, scans: &[TrialScan]) -> Vec<usize> {
let mut sizes = Vec::with_capacity(scans.len());
for scan in scans {
let estimated = if scan.is_dc() {
if scan.comps_in_scan > 1 {
let mut total = 0usize;
for c in 0..scan.comps_in_scan {
let (counter, extra) = self.get(&(c, 0, 0, 0, 0));
total += counter.estimate_encoding_cost() as usize + extra;
}
total
} else {
let (counter, extra) = self.get(&(scan.component, 0, 0, 0, 0));
counter.estimate_encoding_cost() as usize + extra
}
} else {
let key = (scan.component, scan.ss, scan.se, scan.ah, scan.al);
let (counter, extra) = self.get(&key);
counter.estimate_encoding_cost() as usize + extra
};
sizes.push(estimated);
}
sizes
}
pub fn estimate_script_cost_cached(
&mut self,
script: &[super::super::config::ProgressiveScan],
) -> usize {
let mut dc_histograms: Vec<FrequencyCounter> = Vec::new();
let mut ac_histograms: Vec<FrequencyCounter> = Vec::new();
let mut extra_bits_total: usize = 0;
for scan in script {
if scan.ss == 0 && scan.se == 0 {
for &comp in &scan.components {
let (counter, extra) = self.get(&(comp, 0, 0, 0, 0));
dc_histograms.push(counter.clone());
extra_bits_total += extra;
}
} else {
let key = (scan.components[0], scan.ss, scan.se, scan.ah, scan.al);
let (counter, extra) = self.get(&key);
ac_histograms.push(counter.clone());
extra_bits_total += extra;
}
}
let mut huffman_total: usize = 0;
if !dc_histograms.is_empty() {
let dc_result = cluster_histograms(&dc_histograms, 4, false);
for h in &dc_result.cluster_histograms {
huffman_total += h.estimate_encoding_cost() as usize;
}
}
if !ac_histograms.is_empty() {
let ac_result = cluster_histograms(&ac_histograms, 32, false);
for h in &ac_result.cluster_histograms {
huffman_total += h.estimate_encoding_cost() as usize;
}
}
huffman_total + extra_bits_total + script.len() * SOS_HEADER_BITS
}
}
/// Estimates the encoded size of each trial scan independently, one result
/// per entry of `scans`.
///
/// Every scan is costed with its own histogram (no table sharing). A DC scan
/// covering more than one component is charged as the sum of separate Y, Cb
/// and Cr DC estimates. A component index other than 0/1/2 is treated as
/// having no blocks.
pub(crate) fn estimate_all_scan_sizes(
    scans: &[TrialScan],
    y_blocks: &[[i16; DCT_BLOCK_SIZE]],
    cb_blocks: &[[i16; DCT_BLOCK_SIZE]],
    cr_blocks: &[[i16; DCT_BLOCK_SIZE]],
) -> Vec<usize> {
    // One scratch histogram is reused across all scans.
    let mut scratch = FrequencyCounter::new();
    scans
        .iter()
        .map(|scan| {
            scratch.reset();
            let own_blocks: &[[i16; DCT_BLOCK_SIZE]] = match scan.component {
                0 => y_blocks,
                1 => cb_blocks,
                2 => cr_blocks,
                _ => &[],
            };
            if !scan.is_dc() {
                if scan.ah == 0 {
                    estimate_ac_first_scan(&mut scratch, own_blocks, scan.ss, scan.se, scan.al)
                } else {
                    estimate_ac_refinement_scan(own_blocks, scan.ss, scan.se, scan.ah, scan.al)
                }
            } else if scan.comps_in_scan > 1 {
                [y_blocks, cb_blocks, cr_blocks]
                    .iter()
                    .map(|component| estimate_dc_scan(&mut scratch, component))
                    .sum::<usize>()
            } else {
                estimate_dc_scan(&mut scratch, own_blocks)
            }
        })
        .collect()
}
/// Estimates the size of a single-component DC scan over `blocks`, leaving
/// the DC category histogram in `counter` (which is reset first).
///
/// Each block's DC value is coded as a difference from the previous block's DC
/// (starting from 0). The estimate is the Huffman cost of the category symbols
/// plus `category` extra magnitude bits per block, truncated to an integer.
/// Returns 0 for an empty block list.
fn estimate_dc_scan(counter: &mut FrequencyCounter, blocks: &[[i16; DCT_BLOCK_SIZE]]) -> usize {
    counter.reset();
    if blocks.is_empty() {
        return 0;
    }
    // Single pass: count each category and accumulate its magnitude bits at
    // the same time instead of re-walking the blocks afterwards.
    let mut prev_dc = 0i16;
    let mut extra_bits = 0usize;
    for block in blocks {
        let dc = block[0];
        let category = dc_category(dc.wrapping_sub(prev_dc));
        prev_dc = dc;
        counter.count(category);
        extra_bits += usize::from(category);
    }
    // `extra_bits` is an exact integer, so adding it in f64 before truncating
    // keeps the same rounding as summing the per-block bits as f64.
    (counter.estimate_encoding_cost() + extra_bits as f64) as usize
}
/// Estimates the size of a first-pass (`ah == 0`) AC scan over coefficients
/// `ss..=se` with point transform `al`, leaving the symbol histogram in
/// `counter` (which is reset first).
///
/// The estimate is the Huffman cost of the run/size, ZRL and EOB-run symbols
/// plus the value bits of every nonzero coefficient (truncated together to an
/// integer), plus the EOB-run extra bits. Returns 0 for an empty block list.
fn estimate_ac_first_scan(
    counter: &mut FrequencyCounter,
    blocks: &[[i16; DCT_BLOCK_SIZE]],
    ss: u8,
    se: u8,
    al: u8,
) -> usize {
    counter.reset();
    if blocks.is_empty() {
        return 0;
    }
    let ss = ss as usize;
    let se = se as usize;
    let mut eob_run = 0u32;
    let mut eob_extra_bits = 0usize;
    // Accumulated in the same pass as the symbols rather than by re-walking
    // every block afterwards.
    let mut value_extra_bits = 0usize;
    for block in blocks {
        let mut run = 0u8;
        let mut block_has_nonzero = false;
        for k in ss..=se {
            let abs_coeff = block[k].unsigned_abs() >> al;
            if abs_coeff == 0 {
                run += 1;
                continue;
            }
            // The first nonzero coefficient of a block ends any pending EOB run.
            if !block_has_nonzero && eob_run > 0 {
                eob_extra_bits += count_eob_run(counter, eob_run);
                eob_run = 0;
            }
            block_has_nonzero = true;
            // Zero runs longer than 15 are emitted as ZRL (0xF0) symbols.
            while run >= 16 {
                counter.count(0xF0);
                run -= 16;
            }
            let size = ac_category(abs_coeff);
            counter.count((run << 4) | size);
            value_extra_bits += usize::from(size);
            run = 0;
        }
        // Trailing zeros extend the EOB run; flush at 32767, the longest run a
        // single EOB-run symbol can describe.
        if run > 0 {
            eob_run += 1;
            if eob_run >= 32767 {
                eob_extra_bits += count_eob_run(counter, eob_run);
                eob_run = 0;
            }
        }
    }
    if eob_run > 0 {
        eob_extra_bits += count_eob_run(counter, eob_run);
    }
    // `value_extra_bits` is an exact integer, so the f64 addition matches the
    // rounding of summing per-coefficient bits as f64.
    (counter.estimate_encoding_cost() + value_extra_bits as f64) as usize + eob_extra_bits
}
fn estimate_ac_refinement_scan(
blocks: &[[i16; DCT_BLOCK_SIZE]],
ss: u8,
se: u8,
ah: u8,
al: u8,
) -> usize {
if blocks.is_empty() {
return 0;
}
let ss = ss as usize;
let se = se as usize;
let mut counter = FrequencyCounter::new();
let mut total_refbits = 0usize;
let mut eob_run = 0u32;
let mut eob_extra_bits = 0usize;
for block in blocks {
let mut run = 0u8;
let mut block_has_newly_sig = false;
for k in ss..=se {
let coeff = block[k];
let abs_coeff = coeff.unsigned_abs();
let prev_nonzero = (abs_coeff >> ah) > 0;
let cur_bit = (abs_coeff >> al) & 1;
if prev_nonzero {
total_refbits += 1;
} else if cur_bit != 0 {
if !block_has_newly_sig && eob_run > 0 {
eob_extra_bits += count_eob_run(&mut counter, eob_run);
eob_run = 0;
}
block_has_newly_sig = true;
while run >= 16 {
counter.count(0xF0); run -= 16;
}
let symbol = (run << 4) | 1;
counter.count(symbol);
total_refbits += 1; run = 0;
} else {
run += 1;
}
}
if !block_has_newly_sig || run > 0 {
eob_run += 1;
if eob_run >= 32767 {
eob_extra_bits += count_eob_run(&mut counter, eob_run);
eob_run = 0;
}
}
}
if eob_run > 0 {
eob_extra_bits += count_eob_run(&mut counter, eob_run);
}
let huffman_cost = counter.estimate_encoding_cost();
(huffman_cost as usize) + total_refbits + eob_extra_bits
}
/// Fixed per-scan charge that the script-cost estimators add once for every
/// scan in a script (`script.len() * SOS_HEADER_BITS`).
const SOS_HEADER_BITS: usize = 84;
// NOTE(review): not referenced anywhere in this file; confirm whether it is
// still needed or should be removed.
const SCAN_OVERHEAD_BITS: usize = 150;
/// Estimates the total cost of a complete progressive `script`.
///
/// Each scan's histogram is computed independently. The DC histograms are then
/// clustered with `cluster_histograms` (limit argument 4), and the AC
/// histograms are clustered the same way (limit argument 32). The Huffman cost
/// of the clustered histograms is added to all scans' extra bits plus
/// `SOS_HEADER_BITS` per scan. Component indices other than 0/1/2 contribute
/// nothing.
pub(crate) fn estimate_script_cost(
    script: &[super::super::config::ProgressiveScan],
    y_blocks: &[[i16; DCT_BLOCK_SIZE]],
    cb_blocks: &[[i16; DCT_BLOCK_SIZE]],
    cr_blocks: &[[i16; DCT_BLOCK_SIZE]],
) -> usize {
    let blocks_for = move |comp: u8| match comp {
        0 => Some(y_blocks),
        1 => Some(cb_blocks),
        2 => Some(cr_blocks),
        _ => None,
    };

    let mut dc_histograms: Vec<FrequencyCounter> = Vec::new();
    let mut ac_histograms: Vec<FrequencyCounter> = Vec::new();
    let mut raw_bits = 0usize;

    for scan in script {
        if scan.ss == 0 && scan.se == 0 {
            // DC: one histogram per listed component.
            for blocks in scan.components.iter().filter_map(|&comp| blocks_for(comp)) {
                let (histogram, extra) = estimate_dc_scan_detailed(blocks);
                dc_histograms.push(histogram);
                raw_bits += extra;
            }
        } else if let Some(blocks) = blocks_for(scan.components[0]) {
            // AC: single component, either first pass or refinement.
            let (histogram, extra) = if scan.ah == 0 {
                estimate_ac_first_scan_detailed(blocks, scan.ss, scan.se, scan.al)
            } else {
                estimate_ac_refinement_scan_detailed(blocks, scan.ss, scan.se, scan.ah, scan.al)
            };
            ac_histograms.push(histogram);
            raw_bits += extra;
        }
    }

    let mut huffman_bits = 0usize;
    if !dc_histograms.is_empty() {
        for table in &cluster_histograms(&dc_histograms, 4, false).cluster_histograms {
            huffman_bits += table.estimate_encoding_cost() as usize;
        }
    }
    if !ac_histograms.is_empty() {
        for table in &cluster_histograms(&ac_histograms, 32, false).cluster_histograms {
            huffman_bits += table.estimate_encoding_cost() as usize;
        }
    }
    huffman_bits + raw_bits + script.len() * SOS_HEADER_BITS
}
/// Builds the DC category histogram for `blocks` and counts the magnitude
/// bits that follow each category symbol.
///
/// Each block's DC value is differenced against the previous block's DC
/// (starting from 0). Returns `(histogram, extra_bits)`; an empty block list
/// yields a fresh histogram and 0.
fn estimate_dc_scan_detailed(blocks: &[[i16; DCT_BLOCK_SIZE]]) -> (FrequencyCounter, usize) {
    let mut histogram = FrequencyCounter::new();
    let mut predictor = 0i16;
    let mut magnitude_bits = 0usize;
    for dc in blocks.iter().map(|block| block[0]) {
        let category = dc_category(dc.wrapping_sub(predictor));
        predictor = dc;
        histogram.count(category);
        magnitude_bits += usize::from(category);
    }
    (histogram, magnitude_bits)
}
/// Builds the symbol histogram for a first-pass (`ah == 0`) AC scan over
/// coefficients `ss..=se` with point transform `al`.
///
/// Returns `(histogram, extra_bits)`. `extra_bits` is the value bits of every
/// nonzero coefficient plus the EOB-run extra bits. An empty block list yields
/// a fresh histogram and 0.
fn estimate_ac_first_scan_detailed(
    blocks: &[[i16; DCT_BLOCK_SIZE]],
    ss: u8,
    se: u8,
    al: u8,
) -> (FrequencyCounter, usize) {
    let mut counter = FrequencyCounter::new();
    if blocks.is_empty() {
        return (counter, 0);
    }
    let ss = ss as usize;
    let se = se as usize;
    let mut eob_run = 0u32;
    let mut eob_extra_bits = 0usize;
    // Accumulated alongside the symbols instead of in a second full pass.
    let mut value_extra_bits = 0usize;
    for block in blocks {
        let mut run = 0u8;
        let mut block_has_nonzero = false;
        for k in ss..=se {
            let abs_coeff = block[k].unsigned_abs() >> al;
            if abs_coeff == 0 {
                run += 1;
                continue;
            }
            // The first nonzero coefficient of a block ends any pending EOB run.
            if !block_has_nonzero && eob_run > 0 {
                eob_extra_bits += count_eob_run(&mut counter, eob_run);
                eob_run = 0;
            }
            block_has_nonzero = true;
            // Zero runs longer than 15 are emitted as ZRL (0xF0) symbols.
            while run >= 16 {
                counter.count(0xF0);
                run -= 16;
            }
            let size = ac_category(abs_coeff);
            counter.count((run << 4) | size);
            value_extra_bits += usize::from(size);
            run = 0;
        }
        // Trailing zeros extend the EOB run; flush at 32767, the longest run a
        // single EOB-run symbol can describe.
        if run > 0 {
            eob_run += 1;
            if eob_run >= 32767 {
                eob_extra_bits += count_eob_run(&mut counter, eob_run);
                eob_run = 0;
            }
        }
    }
    if eob_run > 0 {
        eob_extra_bits += count_eob_run(&mut counter, eob_run);
    }
    (counter, value_extra_bits + eob_extra_bits)
}
/// Builds the symbol histogram for an AC refinement scan over coefficients
/// `ss..=se`, refining from bit position `ah` down to `al`.
///
/// A coefficient with `|c| >> ah != 0` was already significant and costs one
/// correction bit. A coefficient that becomes significant at bit `al` emits a
/// run/size symbol with size 1, plus one bit. Runs of not-yet-significant
/// zeros are tracked for ZRL and EOB-run symbols. Returns
/// `(histogram, correction_bits + eob_run_bits)`; an empty block list yields a
/// fresh histogram and 0.
fn estimate_ac_refinement_scan_detailed(
    blocks: &[[i16; DCT_BLOCK_SIZE]],
    ss: u8,
    se: u8,
    ah: u8,
    al: u8,
) -> (FrequencyCounter, usize) {
    let mut histogram = FrequencyCounter::new();
    if blocks.is_empty() {
        return (histogram, 0);
    }
    let band = usize::from(ss)..=usize::from(se);
    let mut correction_bits = 0usize;
    let mut pending_eobs = 0u32;
    let mut eob_run_bits = 0usize;
    for block in blocks {
        let mut zeros = 0u8;
        let mut saw_new_coeff = false;
        for k in band.clone() {
            let magnitude = block[k].unsigned_abs();
            let already_significant = (magnitude >> ah) > 0;
            let new_bit = (magnitude >> al) & 1;
            if already_significant {
                correction_bits += 1;
                continue;
            }
            if new_bit == 0 {
                zeros += 1;
                continue;
            }
            // Newly significant coefficient: close any pending EOB run first.
            if !saw_new_coeff && pending_eobs > 0 {
                eob_run_bits += count_eob_run(&mut histogram, pending_eobs);
                pending_eobs = 0;
            }
            saw_new_coeff = true;
            while zeros >= 16 {
                histogram.count(0xF0);
                zeros -= 16;
            }
            histogram.count((zeros << 4) | 1);
            correction_bits += 1;
            zeros = 0;
        }
        if zeros > 0 || !saw_new_coeff {
            pending_eobs += 1;
            if pending_eobs >= 32767 {
                eob_run_bits += count_eob_run(&mut histogram, pending_eobs);
                pending_eobs = 0;
            }
        }
    }
    if pending_eobs > 0 {
        eob_run_bits += count_eob_run(&mut histogram, pending_eobs);
    }
    (histogram, correction_bits + eob_run_bits)
}
/// Records one EOB-run symbol for a run of `n` blocks in `counter` and
/// returns the number of extra bits that follow it.
///
/// The run is coded as category `floor(log2(n))` (symbol `category << 4`)
/// followed by `category` extra bits. `n` must therefore be in `1..=32767`,
/// the largest run the category-14 symbol can describe; callers flush pending
/// runs before exceeding that bound.
fn count_eob_run(counter: &mut FrequencyCounter, n: u32) -> usize {
    debug_assert!(
        (1..=32767).contains(&n),
        "EOB run length out of range: {}",
        n
    );
    let category = 31 - n.leading_zeros();
    counter.count((category as u8) << 4);
    category as usize
}
/// Returns the magnitude category of a DC difference: the number of bits
/// needed to represent `|diff|`, or 0 when `diff` is 0.
#[inline]
fn dc_category(diff: i16) -> u8 {
    // `0u16.leading_zeros()` is 16, so a zero difference maps to category 0
    // without a separate branch.
    (u16::BITS - diff.unsigned_abs().leading_zeros()) as u8
}
/// Returns the magnitude category of an AC coefficient magnitude: the number
/// of bits needed to represent `abs_val`, or 0 when it is 0.
#[inline]
fn ac_category(abs_val: u16) -> u8 {
    // Bit length of `abs_val`; zero has 16 leading zeros and maps to 0.
    (u16::BITS - abs_val.leading_zeros()) as u8
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::encode::scan_optimize::ScanSearchConfig;
    use crate::encode::scan_optimize::generate::generate_search_scans;

    /// Small Y/Cb/Cr block sets with a DC ramp and a few AC coefficients,
    /// shared by the size-estimation tests.
    fn sample_blocks() -> (Vec<[i16; 64]>, Vec<[i16; 64]>, Vec<[i16; 64]>) {
        let mut y_blocks = vec![[0i16; 64]; 64];
        let mut cb_blocks = vec![[0i16; 64]; 64];
        let cr_blocks = vec![[0i16; 64]; 64];
        for (i, block) in y_blocks.iter_mut().enumerate() {
            block[0] = (i as i16) * 10;
            block[1] = 5;
            block[2] = -3;
            if i % 4 == 0 {
                block[10] = 2;
                block[20] = -1;
            }
        }
        for (i, block) in cb_blocks.iter_mut().enumerate() {
            block[0] = (i as i16) * 5;
            block[1] = 2;
        }
        (y_blocks, cb_blocks, cr_blocks)
    }

    #[test]
    fn test_dc_category() {
        assert_eq!(dc_category(0), 0);
        assert_eq!(dc_category(1), 1);
        assert_eq!(dc_category(-1), 1);
        assert_eq!(dc_category(2), 2);
        assert_eq!(dc_category(3), 2);
        assert_eq!(dc_category(4), 3);
        assert_eq!(dc_category(7), 3);
        assert_eq!(dc_category(-7), 3);
        assert_eq!(dc_category(255), 8);
    }

    #[test]
    fn test_ac_category() {
        assert_eq!(ac_category(0), 0);
        assert_eq!(ac_category(1), 1);
        assert_eq!(ac_category(2), 2);
        assert_eq!(ac_category(3), 2);
        assert_eq!(ac_category(255), 8);
        assert_eq!(ac_category(1023), 10);
    }

    #[test]
    fn test_eob_run_encoding() {
        let mut counter = FrequencyCounter::new();
        assert_eq!(count_eob_run(&mut counter, 1), 0);
        counter.reset();
        assert_eq!(count_eob_run(&mut counter, 2), 1);
        counter.reset();
        assert_eq!(count_eob_run(&mut counter, 4), 2);
        counter.reset();
        assert_eq!(count_eob_run(&mut counter, 100), 6);
        counter.reset();
        assert_eq!(count_eob_run(&mut counter, 32767), 14);
    }

    #[test]
    fn test_eob_runs_cheaper_than_individual() {
        let zero_blocks = vec![[0i16; 64]; 1000];
        let mut counter = FrequencyCounter::new();
        let cost_with_runs = estimate_ac_first_scan(&mut counter, &zero_blocks, 1, 63, 0);
        assert!(
            cost_with_runs < 300,
            "1000 zero-block EOB run should be very cheap, got {}",
            cost_with_runs
        );
    }

    #[test]
    fn test_estimate_zero_blocks() {
        let config = ScanSearchConfig::default();
        let scans = generate_search_scans(3, &config);
        let zero_blocks = vec![[0i16; 64]; 100];
        let sizes = estimate_all_scan_sizes(&scans, &zero_blocks, &zero_blocks, &zero_blocks);
        assert_eq!(sizes.len(), 64);
        for (i, &size) in sizes.iter().enumerate() {
            assert!(
                size < 1_000_000,
                "Scan {} has unreasonably large size: {}",
                i,
                size
            );
        }
    }

    #[test]
    fn test_estimate_produces_valid_sizes() {
        let config = ScanSearchConfig::default();
        let scans = generate_search_scans(3, &config);
        let (y_blocks, cb_blocks, cr_blocks) = sample_blocks();
        let sizes = estimate_all_scan_sizes(&scans, &y_blocks, &cb_blocks, &cr_blocks);
        assert_eq!(sizes.len(), 64);
        assert!(sizes[0] > 0, "DC scan should have non-zero cost");
        assert!(sizes[1] > 0, "Y AC 1-8 should have non-zero cost with data");
    }

    #[test]
    fn test_cached_sizes_match_uncached() {
        // The cache is only an optimization: it must report the same
        // per-scan sizes as the direct estimator for the same inputs.
        let config = ScanSearchConfig::default();
        let scans = generate_search_scans(3, &config);
        let (y_blocks, cb_blocks, cr_blocks) = sample_blocks();
        let uncached = estimate_all_scan_sizes(&scans, &y_blocks, &cb_blocks, &cr_blocks);
        let mut cache =
            ScanHistogramCache::warm(&scans, &[], &y_blocks, &cb_blocks, &cr_blocks, 3);
        let cached = cache.estimate_all_scan_sizes_cached(&scans);
        assert_eq!(cached, uncached);
    }

    #[test]
    fn test_dc_scan_monotonic_with_more_blocks() {
        let mut counter = FrequencyCounter::new();
        let small_blocks = vec![[10i16; 64]; 10];
        let large_blocks = vec![[10i16; 64]; 100];
        let small_cost = estimate_dc_scan(&mut counter, &small_blocks);
        let large_cost = estimate_dc_scan(&mut counter, &large_blocks);
        assert!(
            large_cost >= small_cost,
            "More blocks should cost at least as much: {} < {}",
            large_cost,
            small_cost
        );
    }

    #[test]
    fn test_refinement_has_refbits() {
        let mut blocks = vec![[0i16; 64]; 10];
        for block in blocks.iter_mut() {
            block[1] = 2;
            block[2] = 3;
        }
        let cost = estimate_ac_refinement_scan(&blocks, 1, 63, 1, 0);
        assert!(cost > 0, "Refinement scan should have non-zero cost");
    }
}