#![allow(dead_code)]
use super::frequency::FrequencyCounter;
/// Outcome of [`cluster_histograms()`]: the cluster chosen for every context
/// and the slot index (0–3) assigned to every cluster.
#[derive(Clone, Debug)]
pub struct ClusterResult {
    /// Cluster index for each context, indexed by context.
    pub context_map: Vec<usize>,
    /// Accumulated histogram of each cluster, indexed by cluster.
    pub cluster_histograms: Vec<FrequencyCounter>,
    /// Number of clusters; set by [`cluster_histograms()`] to
    /// `cluster_histograms.len()` when clustering finishes.
    pub num_clusters: usize,
    /// Slot index assigned to each cluster, indexed by cluster.
    pub slot_ids: Vec<usize>,
    /// Debug record of every merge as `(context, target slot, cost delta)`.
    #[cfg(feature = "__debug-tokens")]
    pub merge_log: Vec<(usize, usize, f64)>,
}

impl ClusterResult {
    /// Creates an empty result for `num_contexts` contexts, all initially
    /// mapped to cluster 0, with no clusters or slots yet.
    pub fn new(num_contexts: usize) -> Self {
        Self {
            context_map: vec![0; num_contexts],
            cluster_histograms: Vec::new(),
            num_clusters: 0,
            slot_ids: Vec::new(),
            #[cfg(feature = "__debug-tokens")]
            merge_log: Vec::new(),
        }
    }

    /// Returns the slot assigned to the cluster of `context`.
    ///
    /// An out-of-range context is treated as cluster 0, and a cluster with
    /// no recorded slot yields slot 0.
    #[inline]
    pub fn get_slot(&self, context: usize) -> usize {
        let cluster = self.context_map.get(context).copied().unwrap_or(0);
        self.slot_ids.get(cluster).copied().unwrap_or(0)
    }

    /// Writes `merge_log` to `path` as a JSON array with one object per merge.
    ///
    /// Each object has the keys `ctx_a` (the merged context), `ctx_b` (the
    /// target *slot* index, despite the key name) and `cost_delta`. A
    /// non-finite cost delta is written as `null`, because JSON has no
    /// representation for NaN or infinity.
    ///
    /// # Errors
    /// Returns any I/O error raised while creating, writing, or flushing the file.
    #[cfg(feature = "__debug-tokens")]
    pub fn dump_merge_log(&self, path: &str) -> std::io::Result<()> {
        use std::io::Write;
        // Buffer the many small line writes; flush explicitly so errors are
        // reported instead of being swallowed when the writer is dropped.
        let mut out = std::io::BufWriter::new(std::fs::File::create(path)?);
        writeln!(out, "[")?;
        let total = self.merge_log.len();
        for (i, &(a, b, cost)) in self.merge_log.iter().enumerate() {
            let comma = if i + 1 < total { "," } else { "" };
            if cost.is_finite() {
                writeln!(
                    out,
                    r#" {{"ctx_a":{},"ctx_b":{},"cost_delta":{:.4}}}{}"#,
                    a, b, cost, comma
                )?;
            } else {
                writeln!(
                    out,
                    r#" {{"ctx_a":{},"ctx_b":{},"cost_delta":null}}{}"#,
                    a, b, comma
                )?;
            }
        }
        writeln!(out, "]")?;
        out.flush()
    }
}
/// Layout of the context index space.
///
/// Contexts `0..ac_offset` are DC contexts (component indices clamped to 3).
/// AC contexts start at `ac_offset`. Each scan owns a contiguous run of AC
/// contexts that begins at its entry in `scan_ac_offsets`.
#[derive(Clone, Debug)]
pub struct ContextConfig {
    /// Total number of contexts (DC followed by AC).
    pub num_contexts: usize,
    /// Index of the first AC context; both constructors set this to 4.
    pub ac_offset: usize,
    /// First AC context of each scan, indexed by scan.
    pub scan_ac_offsets: Vec<usize>,
}

impl ContextConfig {
    /// Layout with 4 DC contexts followed by one AC context per component,
    /// all belonging to a single scan (scan 0).
    pub fn for_sequential(num_components: usize) -> Self {
        Self {
            num_contexts: 4 + num_components,
            ac_offset: 4,
            scan_ac_offsets: vec![4],
        }
    }

    /// Layout built from a list of scans.
    ///
    /// `scans` yields one `(ss, se, comps_in_scan)` tuple per scan, in scan
    /// order. Every scan records the current AC offset. A scan with `se > 0`
    /// then reserves `comps_in_scan` AC contexts; a scan with `se == 0`
    /// reserves none. `ss` is not inspected.
    ///
    /// Accepts any [`IntoIterator`], so both iterators and collections such
    /// as `Vec<(u8, u8, usize)>` can be passed.
    ///
    /// `num_components` is not used by this constructor.
    pub fn for_progressive<I>(num_components: usize, scans: I) -> Self
    where
        I: IntoIterator<Item = (u8, u8, usize)>,
    {
        let _ = num_components;
        let mut num_ac_contexts = 0;
        let mut scan_ac_offsets = Vec::new();
        for (_ss, se, comps_in_scan) in scans {
            scan_ac_offsets.push(4 + num_ac_contexts);
            if se > 0 {
                num_ac_contexts += comps_in_scan;
            }
        }
        Self {
            num_contexts: 4 + num_ac_contexts,
            ac_offset: 4,
            scan_ac_offsets,
        }
    }

    /// DC context for `component`; components 3 and above share context 3.
    #[inline]
    pub fn dc_context(&self, component: usize) -> usize {
        component.min(3)
    }

    /// AC context for component `comp_in_scan` of scan `scan_idx`.
    ///
    /// A scan with no recorded offset falls back to `ac_offset`.
    #[inline]
    pub fn ac_context(&self, scan_idx: usize, comp_in_scan: usize) -> usize {
        let base = self
            .scan_ac_offsets
            .get(scan_idx)
            .copied()
            .unwrap_or(self.ac_offset);
        base + comp_in_scan
    }

    /// Number of DC contexts: `ac_offset`, capped at 4.
    #[inline]
    pub fn num_dc_contexts(&self) -> usize {
        self.ac_offset.min(4)
    }

    /// Number of AC contexts, i.e. contexts at or after `ac_offset`.
    #[inline]
    pub fn num_ac_contexts(&self) -> usize {
        self.num_contexts.saturating_sub(self.ac_offset)
    }
}
/// Greedily groups per-context histograms into clusters, giving each cluster
/// one of at most four slots.
///
/// Contexts are visited in order. For each non-empty histogram, the cost of
/// starting a new cluster (`estimate_encoding_cost` of the histogram alone)
/// is compared with the cost increase of merging it into the cluster that
/// currently occupies each slot. The cheapest option wins. Once all four
/// slots are in use, a new cluster takes the slot after the one given to the
/// most recently created cluster (round-robin). The displaced cluster can no
/// longer receive merges.
///
/// * `histograms` – one histogram per context. Empty histograms are mapped
///   to cluster 0 and are not merged anywhere.
/// * `max_clusters` – upper bound on the total number of clusters created.
///   A value of 0 is treated as 1, because at least one cluster is needed to
///   hold any non-empty histogram.
/// * `force_baseline` – additionally caps the cluster count at 2.
///
/// Returns a [`ClusterResult`] whose `context_map[ctx]` is the cluster of
/// each context and whose `slot_ids[cluster]` is that cluster's slot.
pub fn cluster_histograms(
    histograms: &[FrequencyCounter],
    max_clusters: usize,
    force_baseline: bool,
) -> ClusterResult {
    const MAX_SLOTS: usize = 4;

    let mut result = ClusterResult::new(histograms.len());
    // Cluster index currently held by each slot, and that cluster's running cost.
    let mut slot_histograms: Vec<usize> = Vec::with_capacity(MAX_SLOTS);
    let mut slot_costs: Vec<f64> = Vec::with_capacity(MAX_SLOTS);
    let cap = if force_baseline {
        max_clusters.min(2)
    } else {
        max_clusters
    };
    // A cap of zero would leave nowhere to put the first non-empty histogram.
    let effective_max = cap.max(1);
    #[cfg(feature = "__debug-tokens")]
    let mut merge_log = Vec::new();

    for (ctx_idx, histo) in histograms.iter().enumerate() {
        if histo.is_empty_histogram() {
            result.context_map[ctx_idx] = 0;
            continue;
        }
        let num_slots = slot_histograms.len();
        // The cap is checked against the total number of clusters, not slots.
        // Slots are recycled after four, so the slot count alone could never
        // enforce a cap above four.
        let at_capacity = result.cluster_histograms.len() >= effective_max;
        let mut best_slot = num_slots;
        let mut best_cost = if at_capacity {
            f64::MAX
        } else {
            histo.estimate_encoding_cost()
        };
        for (slot_idx, (&cluster_idx, &slot_cost)) in
            slot_histograms.iter().zip(&slot_costs).enumerate()
        {
            let combined_cost = result.cluster_histograms[cluster_idx]
                .combined(histo)
                .estimate_encoding_cost();
            let cost_delta = combined_cost - slot_cost;
            if cost_delta < best_cost {
                best_cost = cost_delta;
                best_slot = slot_idx;
            }
        }

        if best_slot == num_slots && !at_capacity {
            // Starting a new cluster is cheapest and still allowed.
            let cluster_idx = result.cluster_histograms.len();
            result.cluster_histograms.push(histo.clone());
            result.context_map[ctx_idx] = cluster_idx;
            let slot = if num_slots < MAX_SLOTS {
                slot_histograms.push(cluster_idx);
                slot_costs.push(best_cost);
                num_slots
            } else {
                let replace_slot =
                    (result.slot_ids.last().copied().unwrap_or(0) + 1) % MAX_SLOTS;
                slot_histograms[replace_slot] = cluster_idx;
                slot_costs[replace_slot] = best_cost;
                replace_slot
            };
            result.slot_ids.push(slot);
        } else {
            // Merge path. `best_slot == num_slots` is only possible here when
            // at capacity and no slot produced a cost below the sentinel
            // (e.g. NaN estimates). Fall back to slot 0 with its real cost
            // delta rather than the `f64::MAX` sentinel, which would poison
            // that slot's running cost. Slot 0 exists because being at
            // capacity implies at least one cluster (and thus one slot).
            let (target_slot, cost_delta) = if best_slot < num_slots {
                (best_slot, best_cost)
            } else {
                let combined_cost = result.cluster_histograms[slot_histograms[0]]
                    .combined(histo)
                    .estimate_encoding_cost();
                (0, combined_cost - slot_costs[0])
            };
            let cluster_idx = slot_histograms[target_slot];
            result.cluster_histograms[cluster_idx].add(histo);
            result.context_map[ctx_idx] = cluster_idx;
            slot_costs[target_slot] += cost_delta;
            #[cfg(feature = "__debug-tokens")]
            merge_log.push((ctx_idx, target_slot, cost_delta));
        }
    }

    result.num_clusters = result.cluster_histograms.len();
    #[cfg(feature = "__debug-tokens")]
    {
        result.merge_log = merge_log;
    }
    result
}