use alloc::vec::Vec;
use super::context::{self, ContextMode};
/// Number of literal contexts. `cluster` expects exactly one histogram per
/// context and produces a context map (`cmap`) with this many entries.
/// NOTE(review): presumably `context_id` always returns a value below this — confirm.
pub(crate) const NUM_CONTEXTS: usize = 64;
/// Upper bound on the number of literal trees.
/// NOTE(review): not referenced in this chunk; presumably callers pass it as
/// `cluster`'s `max_trees` — confirm.
pub(crate) const MAX_LITERAL_TREES: usize = 16;
/// Context modes considered for literal modeling.
/// NOTE(review): presumably each mode is clustered and the one with the lowest
/// `est_cost_bits` is kept; the selection logic is not in this chunk.
pub(crate) const CANDIDATE_MODES: [ContextMode; 4] = [
ContextMode::Utf8,
ContextMode::Msb6,
ContextMode::Lsb6,
ContextMode::Signed,
];
/// Result of clustering the literal contexts for one context mode, as built
/// by `cluster`.
pub(crate) struct LiteralContextModel {
/// Context mode the model was built for.
pub mode: ContextMode,
/// 256-bin literal histograms produced by `cluster`; see `cluster` for how
/// the entries are indexed.
pub histograms: Vec<[u32; 256]>,
/// Context map with `NUM_CONTEXTS` entries: `cmap[c]` is the tree id used
/// for context `c`.
pub cmap: Vec<u8>,
/// Number of distinct tree ids used in `cmap` (always at least 1).
pub num_trees: u32,
/// Estimated encoded size in whole bits: literal data cost plus one header
/// cost per tree (0 when there were no literals).
pub est_cost_bits: u64,
}
/// Estimated cost, in 1/256-bit fixed-point units, of coding every symbol
/// counted in `hist` with a code derived from those same counts.
///
/// Computes `sum(count * (log2(total) - log2(count)))` over the non-zero
/// bins, using the `log2_fixed` approximation. `total` is expected to be the
/// sum of `hist`; an empty histogram (`total == 0`) costs nothing.
fn histogram_bits(hist: &[u32; 256], total: u32) -> u64 {
    if total == 0 {
        return 0;
    }
    let log_total = log2_fixed(u64::from(total));
    hist.iter()
        .copied()
        .filter(|&count| count != 0)
        .map(|count| {
            let count = u64::from(count);
            count * (log_total - log2_fixed(count))
        })
        .sum()
}
/// Fixed-point base-2 logarithm with 8 fractional bits (`256` represents 1.0).
///
/// The integer part is exact (`floor(log2(x))`). The fractional part linearly
/// interpolates between the neighbouring powers of two, so the result never
/// exceeds `log2(x) * 256` and is non-decreasing in `x`. Valid for every
/// `u64` input >= 1; the invalid input 0 trips the debug assertion and
/// returns 0 in release builds.
fn log2_fixed(x: u64) -> u64 {
    debug_assert!(x >= 1);
    if x <= 1 {
        return 0;
    }
    // x >= 2 here, so leading_zeros() <= 62 and this cannot underflow.
    let floor = u64::from(63 - x.leading_zeros());
    let rem = x - (1u64 << floor);
    // Scale rem / 2^floor to 1/256 units. Shifting `rem` left by 8 first
    // silently drops high bits once x >= 2^56, so shift right when possible;
    // both forms are exact and agree wherever the left shift does not overflow.
    let frac = if floor >= 8 {
        rem >> (floor - 8)
    } else {
        (rem << 8) >> floor
    };
    floor * 256 + frac
}
/// Estimated cost, in 1/256-bit fixed-point units, of coding the symbols of
/// both `a` and `b` with a single code derived from their bin-wise sum.
///
/// `at` and `bt` are expected to be the totals of `a` and `b`. Counts and
/// totals are combined in `u64`, so summing two large `u32` histograms cannot
/// overflow. Returns 0 when both histograms are empty.
fn merged_bits(a: &[u32; 256], at: u32, b: &[u32; 256], bt: u32) -> u64 {
    let total = u64::from(at) + u64::from(bt);
    if total == 0 {
        return 0;
    }
    let log_total = log2_fixed(total);
    a.iter()
        .zip(b.iter())
        .map(|(&x, &y)| u64::from(x) + u64::from(y))
        .filter(|&count| count != 0)
        .map(|count| count * (log_total - log2_fixed(count)))
        .sum()
}
/// Estimated cost of one tree header (140 bits), expressed in the same
/// 1/256-bit fixed-point units returned by `histogram_bits`/`merged_bits`.
/// `cluster` subtracts it from each merge's cost delta (a merge saves one
/// header) and adds `HEADER_COST_BITS / 256` whole bits per tree to its cost
/// estimate.
/// NOTE(review): 140 looks like a heuristic; its origin is not shown here.
const HEADER_COST_BITS: u64 = 140 * 256;
/// Greedily clusters the per-context literal histograms for `mode` into trees.
///
/// `histograms` must contain one 256-bin histogram per context
/// (`NUM_CONTEXTS` entries). Each non-empty context starts as its own
/// cluster. Empty contexts join the cluster of the first non-empty context so
/// they never cost an extra tree. The pair whose merge changes the estimated
/// cost the least is merged repeatedly; the estimate is data bits plus one
/// `HEADER_COST_BITS` per tree. Merging stops once the cluster count is within
/// `max_trees` and every remaining merge would strictly increase the estimate.
///
/// Returns a [`LiteralContextModel`] where:
/// - `cmap[c]` is the tree id of context `c`; ids are numbered densely in
///   order of first appearance over contexts `0..NUM_CONTEXTS`;
/// - `histograms[t]` is the merged histogram of tree `t`, and
///   `histograms.len() == num_trees` (with no literals, a single zero
///   histogram);
/// - `est_cost_bits` is the estimated encoded size in whole bits.
pub(crate) fn cluster(
    mode: ContextMode,
    mut histograms: Vec<[u32; 256]>,
    max_trees: usize,
) -> LiteralContextModel {
    debug_assert_eq!(histograms.len(), NUM_CONTEXTS);
    let mut totals: Vec<u32> = histograms.iter().map(|h| h.iter().sum::<u32>()).collect();
    let mut active: Vec<usize> = (0..NUM_CONTEXTS).filter(|&c| totals[c] > 0).collect();
    if active.is_empty() {
        // No literals at all: a single empty tree serves every context.
        return LiteralContextModel {
            mode,
            histograms: alloc::vec![[0u32; 256]],
            cmap: alloc::vec![0u8; NUM_CONTEXTS],
            num_trees: 1,
            est_cost_bits: 0,
        };
    }

    // cluster_of[c] is the representative context whose histogram currently
    // holds the merged counts for context c's cluster.
    let first_active = active[0];
    let mut cluster_of: Vec<usize> = (0..NUM_CONTEXTS)
        .map(|c| if totals[c] == 0 { first_active } else { c })
        .collect();
    // Cached cost of every cluster. Only the surviving cluster of a merge
    // changes, so the pair search does not recompute these per pair.
    let mut bits: Vec<u64> = (0..NUM_CONTEXTS)
        .map(|c| histogram_bits(&histograms[c], totals[c]))
        .collect();

    while active.len() > 1 {
        // Above the tree budget a merge is mandatory even if it costs bits.
        let force = active.len() > max_trees;
        let mut best_i = 0usize;
        let mut best_j = 0usize;
        let mut best_delta: i64 = i64::MAX;
        for ai in 0..active.len() {
            let ci = active[ai];
            for aj in (ai + 1)..active.len() {
                let cj = active[aj];
                let bm = merged_bits(&histograms[ci], totals[ci], &histograms[cj], totals[cj]);
                // A merge replaces two trees with one and so saves one header;
                // a non-positive delta means the merge does not hurt.
                let delta = bm as i64
                    - bits[ci] as i64
                    - bits[cj] as i64
                    - HEADER_COST_BITS as i64;
                if delta < best_delta {
                    best_delta = delta;
                    best_i = ai;
                    best_j = aj;
                }
            }
        }
        if !force && best_delta > 0 {
            break;
        }
        let ci = active[best_i];
        let cj = active[best_j];
        let src = histograms[cj];
        for (dst, s) in histograms[ci].iter_mut().zip(src.iter()) {
            *dst += *s;
        }
        totals[ci] += totals[cj];
        bits[ci] = histogram_bits(&histograms[ci], totals[ci]);
        for slot in cluster_of.iter_mut() {
            if *slot == cj {
                *slot = ci;
            }
        }
        // best_i < best_j, so swap_remove(best_j) leaves index best_i intact.
        active.swap_remove(best_j);
    }

    // Number the surviving clusters densely in order of first appearance and
    // gather each tree's merged histogram under its new id. The representative
    // contexts still hold counts that were merged away; those are not copied.
    let mut tree_of_cluster: Vec<Option<u8>> = alloc::vec![None; NUM_CONTEXTS];
    let mut tree_histograms: Vec<[u32; 256]> = Vec::with_capacity(active.len());
    let mut cmap = alloc::vec![0u8; NUM_CONTEXTS];
    for (slot, &cl) in cmap.iter_mut().zip(cluster_of.iter()) {
        *slot = match tree_of_cluster[cl] {
            Some(id) => id,
            None => {
                // At most NUM_CONTEXTS (64) trees, so the id fits in a u8.
                let id = tree_histograms.len() as u8;
                tree_histograms.push(histograms[cl]);
                tree_of_cluster[cl] = Some(id);
                id
            }
        };
    }
    // cmap is non-empty, so at least one tree exists.
    let num_trees = tree_histograms.len() as u32;

    // `bits` is current for every surviving cluster.
    let data_bits: u64 = active.iter().map(|&ci| bits[ci]).sum();
    let est_cost_bits = data_bits / 256 + u64::from(num_trees) * (HEADER_COST_BITS / 256);
    LiteralContextModel {
        mode,
        histograms: tree_histograms,
        cmap,
        num_trees,
        est_cost_bits,
    }
}
/// Literal context id for `mode`, computed from `prev1` and `prev2`.
///
/// Thin wrapper that delegates to `context::literal_context`.
/// NOTE(review): `prev1`/`prev2` are presumably the two most recently coded
/// bytes — confirm with callers.
#[inline]
pub(crate) fn context_id(mode: ContextMode, prev1: u8, prev2: u8) -> u8 {
    context::literal_context(mode, prev1, prev2)
}