use hashbrown::HashMap;
use super::token::{Lz77UintCoder, Token, UintCoder};
use crate::bit_writer::BitWriter;
use crate::error::Result;
/// Upper bound (2^20 positions) on the hash-chain window. The parsers in this file use the
/// smallest power of two that covers the token stream, capped at this value.
const WINDOW_SIZE: usize = 1 << 20;
/// Number of entries in `SPECIAL_DISTANCES`.
const NUM_SPECIAL_DISTANCES: usize = 120;
/// Coefficient pairs `[a, b]` for the special distance codes.
///
/// `special_distance` evaluates entry `i` as `a + multiplier * b`.
/// NOTE(review): the pairs look like (dx, dy) pixel offsets with the multiplier being the image
/// width (the tests pass `image_width` as the multiplier) — confirm against the codestream spec.
#[rustfmt::skip]
const SPECIAL_DISTANCES: [[i8; 2]; NUM_SPECIAL_DISTANCES] = [
[0, 1], [1, 0], [1, 1], [-1, 1], [0, 2], [2, 0], [1, 2], [-1, 2],
[2, 1], [-2, 1], [2, 2], [-2, 2], [0, 3], [3, 0], [1, 3], [-1, 3],
[3, 1], [-3, 1], [2, 3], [-2, 3], [3, 2], [-3, 2], [0, 4], [4, 0],
[1, 4], [-1, 4], [4, 1], [-4, 1], [3, 3], [-3, 3], [2, 4], [-2, 4],
[4, 2], [-4, 2], [0, 5], [3, 4], [-3, 4], [4, 3], [-4, 3], [5, 0],
[1, 5], [-1, 5], [5, 1], [-5, 1], [2, 5], [-2, 5], [5, 2], [-5, 2],
[4, 4], [-4, 4], [3, 5], [-3, 5], [5, 3], [-5, 3], [0, 6], [6, 0],
[1, 6], [-1, 6], [6, 1], [-6, 1], [2, 6], [-2, 6], [6, 2], [-6, 2],
[4, 5], [-4, 5], [5, 4], [-5, 4], [3, 6], [-3, 6], [6, 3], [-6, 3],
[0, 7], [7, 0], [1, 7], [-1, 7], [5, 5], [-5, 5], [7, 1], [-7, 1],
[4, 6], [-4, 6], [6, 4], [-6, 4], [2, 7], [-2, 7], [7, 2], [-7, 2],
[3, 7], [-3, 7], [7, 3], [-7, 3], [5, 6], [-5, 6], [6, 5], [-6, 5],
[8, 0], [4, 7], [-4, 7], [7, 4], [-7, 4], [8, 1], [8, 2], [6, 6],
[-6, 6], [8, 3], [5, 7], [-5, 7], [7, 5], [-7, 5], [8, 4], [6, 7],
[-6, 7], [7, 6], [-7, 6], [8, 5], [7, 7], [-7, 7], [8, 6], [8, 7],
];
#[inline]
fn special_distance(index: usize, multiplier: i32) -> i32 {
SPECIAL_DISTANCES[index][0] as i32 + multiplier * SPECIAL_DISTANCES[index][1] as i32
}
/// Static base costs for LZ77 length tokens, looked up by `len_cost`.
///
/// Entry `t` is the cost added for length token `t`; tokens past the end use the last entry.
/// NOTE(review): the values look like empirically fitted bit costs — their origin is not visible
/// in this file.
#[rustfmt::skip]
#[allow(clippy::excessive_precision)]
const LEN_COST_TABLE: [f32; 17] = [
2.797667318563126, 3.213177690381199, 2.5706009246743737,
2.408392498667534, 2.829649191872326, 3.3923087753324577,
4.029267451554331, 4.415576699706408, 4.509357574741465,
9.21481543803004, 10.020590190114898, 11.858671627804766,
12.45853300490526, 11.713105831990857, 12.561996324849314,
13.775477692278367, 13.174027068768641,
];
/// Static base costs for LZ77 distance tokens, looked up by `dist_cost`.
///
/// The table has 139 entries (`NUM_SPECIAL_DISTANCES + 19`). `dist_cost` indexes it with the
/// token returned by `hybrid_uint_encode_7_0_0`, clamped to the last entry.
/// NOTE(review): the values look like empirically fitted bit costs — their origin is not visible
/// in this file.
#[rustfmt::skip]
#[allow(clippy::excessive_precision)]
const DIST_COST_TABLE: [f32; 139] = [
6.368282626312716, 5.680793277090298, 8.347404197105247,
7.641619201599141, 6.914328374119438, 7.959808291537444,
8.70023120759855, 8.71378518934703, 9.379132523982769,
9.110472749092708, 9.159029569270908, 9.430936766731973,
7.278284055315169, 7.8278514904267755, 10.026641158289236,
9.976049229827066, 9.64351607048908, 9.563403863480442,
10.171474111762747, 10.45950155077234, 9.994813912104219,
10.322524683741156, 8.465808729388186, 8.756254166066853,
10.160930174662234, 10.247329273413435, 10.04090403724809,
10.129398517544082, 9.342311691539546, 9.07608009102374,
10.104799540677513, 10.378079384990906, 10.165828974075072,
10.337595322341553, 7.940557464567944, 10.575665823319431,
11.023344321751955, 10.736144698831827, 11.118277044595054,
7.468468230648442, 10.738305230932939, 10.906980780216568,
10.163468216353817, 10.17805759656433, 11.167283670483565,
11.147050200274544, 10.517921919244333, 10.651764778156886,
10.17074446448919, 11.217636876224745, 11.261630721139484,
11.403140815247259, 10.892472096873417, 11.1859607804481,
8.017346947551262, 7.895143720278828, 11.036577113822025,
11.170562110315794, 10.326988722591086, 10.40872184751056,
11.213498225466386, 11.30580635516863, 10.672272515665442,
10.768069466228063, 11.145257364153565, 11.64668307145549,
10.593156194627339, 11.207499484844943, 10.767517766396908,
10.826629811407042, 10.737764794499988, 10.6200448518045,
10.191315385198092, 8.468384171390085, 11.731295299170432,
11.824619886654398, 10.41518844301179, 10.16310536548649,
10.539423685097576, 10.495136599328031, 10.469112847728267,
11.72057686174922, 10.910326337834674, 11.378921834673758,
11.847759036098536, 11.92071647623854, 10.810628276345282,
11.008601085273893, 11.910326337834674, 11.949212023423133,
11.298614839104337, 11.611603659010392, 10.472930394619985,
11.835564720850282, 11.523267392285337, 12.01055816679611,
8.413029688994023, 11.895784139536406, 11.984679534970505,
11.220654278717394, 11.716311684833672, 10.61036646226114,
10.89849965960364, 10.203762898863669, 10.997560826267238,
11.484217379438984, 11.792836176993665, 12.24310468755171,
11.464858097919262, 12.212747017409377, 11.425595666074955,
11.572048533398757, 12.742093965163013, 11.381874288645637,
12.191870445817015, 11.683156920035426, 11.152442115262197,
11.90303691580457, 11.653292787169159, 11.938615382266098,
16.970641701570223, 16.853602280380002, 17.26240782594733,
16.644655390108507, 17.14310889757499, 16.910935455445955,
17.505678976959697, 17.213498225466388,
2.4162310293553024, 3.494587244462329, 3.5258600986408344,
3.4959806589517095, 3.098390886949687, 3.343454654302911,
3.588847442290287, 4.14614790111827, 5.152948641990529,
7.433696808092598, 9.716311684833672,
];
/// Estimates the static cost, in bits, of coding the LZ77 length value `len`.
///
/// `len == 0` maps to token 0 with no extra bits. Any other value maps to token
/// `1 + floor(log2(len))` with `floor(log2(len))` extra bits. The token's base cost comes from
/// `LEN_COST_TABLE` (tokens past the end use the last entry) and the extra bits are added on top.
fn len_cost(len: u32) -> f32 {
    let (token, extra_bits) = match len {
        0 => (0u32, 0u32),
        _ => {
            let floor_log2 = 31 - len.leading_zeros();
            (floor_log2 + 1, floor_log2)
        }
    };
    let last = LEN_COST_TABLE.len() - 1;
    let index = (token as usize).min(last);
    LEN_COST_TABLE[index] + extra_bits as f32
}
/// Estimates the static cost, in bits, of coding the distance symbol `dist`.
///
/// The symbol is split into a token and a raw-bit count by `hybrid_uint_encode_7_0_0`. The
/// token's base cost comes from `DIST_COST_TABLE` (tokens past the end use the last entry) and
/// the raw bits are added on top.
fn dist_cost(dist: u32) -> f32 {
    let (token, extra_bits) = hybrid_uint_encode_7_0_0(dist);
    let fallback = DIST_COST_TABLE[DIST_COST_TABLE.len() - 1];
    let base = DIST_COST_TABLE
        .get(token as usize)
        .copied()
        .unwrap_or(fallback);
    base + extra_bits as f32
}
/// Splits `value` into `(token, nbits)` using the hybrid-uint configuration
/// `(split_exponent = 7, msb_in_token = 0, lsb_in_token = 0)`.
///
/// * Values below the split token `1 << 7 = 128` are their own token and carry no raw bits.
/// * Larger values use token `128 + (n - 7)` where `n = floor(log2(value))`, and carry `n` raw
///   bits.
///
/// Tokens are therefore unique below 128 and grow by one per power of two above it, which is
/// what the 139-entry `DIST_COST_TABLE` layout expects.
fn hybrid_uint_encode_7_0_0(value: u32) -> (u32, u32) {
    const SPLIT_EXPONENT: u32 = 7;
    const SPLIT_TOKEN: u32 = 1 << SPLIT_EXPONENT;

    if value < SPLIT_TOKEN {
        (value, 0)
    } else {
        // value >= 128 guarantees n >= SPLIT_EXPONENT, so the subtraction cannot underflow.
        let n = 31 - value.leading_zeros();
        (SPLIT_TOKEN + (n - SPLIT_EXPONENT), n)
    }
}
/// Parameters describing how LZ77 symbols are embedded in a token stream.
#[derive(Debug, Clone)]
pub struct Lz77Params {
    /// Whether LZ77 is in use; the parsers in this file set it once they decide LZ77 pays off.
    pub enabled: bool,
    /// Offset added to a length token to form its symbol in the histogram.
    pub min_symbol: u32,
    /// Shortest match that is coded; stored lengths are `match_len - min_length`.
    pub min_length: u32,
    /// Context id assigned to distance tokens.
    pub distance_context: u32,
}

impl Lz77Params {
    /// Creates disabled parameters with the defaults used by this encoder.
    ///
    /// `min_symbol` is 512 when `force_huffman` is set and 224 otherwise, `min_length` is 3,
    /// and distance tokens get the context id `num_contexts`.
    pub fn new(num_contexts: usize, force_huffman: bool) -> Self {
        let min_symbol = if force_huffman { 512 } else { 224 };
        let distance_context = num_contexts as u32;
        Self {
            enabled: false,
            min_symbol,
            min_length: 3,
            distance_context,
        }
    }
}
/// Writes the LZ77 header to `writer`.
///
/// With `None` a single `0` bit is written. With `Some(params)` the layout is:
///
/// 1. one `1` bit (note: `params.enabled` itself is not consulted);
/// 2. `min_symbol`: 2-bit selector `0/1/2` for `224/512/4096`, otherwise selector `3` followed
///    by `min_symbol - 8` in 15 bits;
/// 3. `min_length`: 2-bit selector `0/1` for `3/4`, selector `2` plus `min_length - 5` in
///    2 bits for `5..=8`, otherwise selector `3` plus `min_length - 9` in 8 bits;
/// 4. four zero bits.
///
/// NOTE(review): this assumes `BitWriter::write(n, v)` emits `v` in `n` bits, and that the final
/// four zero bits are the hybrid-uint configuration for LZ77 lengths — confirm against the
/// codestream spec.
///
/// # Errors
///
/// Propagates any error returned by `writer.write`.
pub fn write_lz77_header(lz77: Option<&Lz77Params>, writer: &mut BitWriter) -> Result<()> {
    let params = match lz77 {
        Some(params) => params,
        None => {
            writer.write(1, 0)?;
            return Ok(());
        }
    };
    writer.write(1, 1)?;

    let min_symbol = params.min_symbol;
    if min_symbol == 224 {
        writer.write(2, 0)?;
    } else if min_symbol == 512 {
        writer.write(2, 1)?;
    } else if min_symbol == 4096 {
        writer.write(2, 2)?;
    } else {
        writer.write(2, 3)?;
        writer.write(15, (min_symbol - 8) as u64)?;
    }

    let min_length = params.min_length;
    if min_length == 3 {
        writer.write(2, 0)?;
    } else if min_length == 4 {
        writer.write(2, 1)?;
    } else if (5..=8).contains(&min_length) {
        writer.write(2, 2)?;
        writer.write(2, (min_length - 5) as u64)?;
    } else {
        writer.write(2, 3)?;
        writer.write(8, (min_length - 9) as u64)?;
    }

    writer.write(4, 0)?;
    Ok(())
}
/// Per-context symbol cost table built from the histogram of a token stream.
struct SymbolCostEstimator {
// Row-major cost matrix: the cost of `sym` in context `ctx` is at
// `ctx * max_alphabet_size + sym`.
bits: Vec<f32>,
// Width of one row of `bits`: the largest alphabet seen in any context.
max_alphabet_size: usize,
}
impl SymbolCostEstimator {
    /// Builds the cost table from `tokens`.
    ///
    /// Each token is reduced to its hybrid-uint token; LZ77 length tokens are additionally
    /// shifted by `lz77.min_symbol`. Tokens whose context is `>= num_contexts` are ignored.
    /// For every context that saw at least one token, the cost of a symbol is:
    ///
    /// * `-log2(count / total)` (rounded up when `force_huffman` is set) for symbols seen some
    ///   but not all of the time,
    /// * `0` for a symbol that accounts for every token of the context,
    /// * `12` for symbols inside the context's alphabet that were never seen.
    ///
    /// All other entries remain `0`.
    fn new(num_contexts: usize, force_huffman: bool, tokens: &[Token], lz77: &Lz77Params) -> Self {
        const ANS_LOG_TAB_SIZE: f32 = 12.0;

        // Per-context histograms, grown on demand to fit the largest symbol seen.
        let mut histograms: Vec<Vec<u32>> = vec![Vec::new(); num_contexts];
        let mut totals = vec![0u32; num_contexts];
        for token in tokens {
            let symbol = if token.is_lz77_length() {
                Lz77UintCoder::encode(token.value).token + lz77.min_symbol
            } else {
                UintCoder::encode(token.value).token
            };
            let ctx = token.context() as usize;
            if ctx >= num_contexts {
                continue;
            }
            let symbol = symbol as usize;
            let histogram = &mut histograms[ctx];
            if histogram.len() <= symbol {
                histogram.resize(symbol + 1, 0);
            }
            histogram[symbol] += 1;
            totals[ctx] += 1;
        }

        let max_alphabet_size = histograms.iter().map(Vec::len).max().unwrap_or(0);
        let mut bits = vec![0.0f32; num_contexts * max_alphabet_size];
        for (ctx, (histogram, &total)) in histograms.iter().zip(&totals).enumerate() {
            if total == 0 {
                continue;
            }
            let inv_total = 1.0 / (total as f32 + 1e-8);
            let row = &mut bits[ctx * max_alphabet_size..][..histogram.len()];
            for (slot, &count) in row.iter_mut().zip(histogram) {
                *slot = if count == 0 {
                    ANS_LOG_TAB_SIZE
                } else if count == total {
                    0.0
                } else {
                    let entropy = -jxl_simd::fast_log2f(count as f32 * inv_total);
                    if force_huffman {
                        entropy.ceil()
                    } else {
                        entropy
                    }
                };
            }
        }

        Self {
            bits,
            max_alphabet_size,
        }
    }

    /// Returns the tabulated cost of `sym` in context `ctx`, or `12.0` when `sym` lies beyond
    /// the table width.
    #[inline]
    fn symbol_cost(&self, ctx: usize, sym: usize) -> f32 {
        if sym >= self.max_alphabet_size {
            return 12.0;
        }
        self.bits[ctx * self.max_alphabet_size + sym]
    }

    /// Penalty for introducing an LZ77 symbol into context `ctx`.
    ///
    /// Computed as `max(0, 6 - mean)` where `mean` is the unweighted average of the row entries
    /// that are below `12.0`. Returns `0.0` when no entry qualifies.
    fn add_symbol_cost(&self, ctx: usize) -> f32 {
        let start = ctx * self.max_alphabet_size;
        let row = &self.bits[start..start + self.max_alphabet_size];
        let (sum, seen) = row
            .iter()
            .filter(|&&cost| cost < 12.0)
            .fold((0.0f32, 0u32), |(sum, seen), &cost| (sum + cost, seen + 1));
        if seen == 0 {
            0.0
        } else {
            (6.0 - sum / seen as f32).max(0.0)
        }
    }

    /// Cost of coding LZ77 length value `len` in context `ctx`: raw bits plus the cost of the
    /// length token offset by `lz77.min_symbol`.
    fn len_cost(&self, ctx: usize, len: u32, lz77: &Lz77Params) -> f32 {
        let extra_bits = if len == 0 {
            0
        } else {
            31 - len.leading_zeros()
        };
        let token = if len == 0 { 0 } else { extra_bits + 1 };
        let symbol = (token + lz77.min_symbol) as usize;
        extra_bits as f32 + self.symbol_cost(ctx, symbol)
    }

    /// Cost of coding `dist_symbol` in the distance context: raw bits plus the token cost.
    fn dist_cost_sce(&self, dist_symbol: u32, lz77: &Lz77Params) -> f32 {
        let (token, extra_bits) = UintCoder::encode(dist_symbol).into();
        let ctx = lz77.distance_context as usize;
        extra_bits as f32 + self.symbol_cost(ctx, token as usize)
    }
}
/// Hash-chain index over a token stream, used to find earlier occurrences of the values that
/// start at a given position. All per-position tables are circular and indexed by
/// `pos & window_mask`.
struct HashChain {
/// Token values copied from the input stream.
data: Vec<u32>,
/// Number of tokens in `data`.
size: usize,
/// Number of slots in the circular tables.
window_size: usize,
/// `window_size - 1`; used to wrap positions (assumes `window_size` is a power of two).
window_mask: usize,
/// Shortest match reported by `find_matches`.
min_length: usize,
/// Longest match considered by `find_matches`.
max_length: usize,
/// Number of distinct hash values (size of `head`).
#[allow(dead_code)] hash_num_values: usize,
/// `hash_num_values - 1`; applied to the raw hash.
hash_mask: usize,
/// Left shift applied per value when mixing the hash.
hash_shift: u32,
/// Most recent window position for each hash value, or `-1` when none was inserted yet.
head: Vec<i32>,
/// For each window position, the previous window position with the same hash. An entry equal
/// to its own index marks the end of the chain.
chain: Vec<u32>,
/// Hash value stored for each window position, or `-1` when unset.
val: Vec<i32>,
/// Same as `head`, but keyed by zero-run length.
headz: Vec<i32>,
/// Same as `chain`, but linking positions with the same zero-run length.
chainz: Vec<u32>,
/// Zero-run length recorded for each window position.
zeros: Vec<u32>,
/// Zero-run length at the most recently updated position.
numzeros: u32,
/// Maps a distance to the index of the special distance code that produces it.
special_dist_table: HashMap<i32, usize>,
/// `NUM_SPECIAL_DISTANCES` when special distances are enabled, otherwise 0.
num_special_distances: usize,
/// Maximum number of chain entries visited per search.
max_chain_length: u32,
}
impl HashChain {
/// Creates an empty index over `tokens`.
///
/// `window_size` sizes the circular tables. It must be at least 1 because `window_size - 1` is
/// stored as the position mask; the callers in this file always pass a power of two.
/// `min_length` and `max_length` bound the matches reported by `find_matches`. A non-zero
/// `distance_multiplier` enables the special-distance table.
fn new(
tokens: &[Token],
window_size: usize,
min_length: usize,
max_length: usize,
distance_multiplier: i32,
) -> Self {
let size = tokens.len();
let data: Vec<u32> = tokens.iter().map(|t| t.value).collect();
let hash_num_values = 32768usize;
let hash_mask = hash_num_values - 1;
let hash_shift = 5u32;
let head = vec![-1i32; hash_num_values];
// Every chain entry starts out pointing at itself, which is the "end of chain" marker.
let chain: Vec<u32> = (0..window_size as u32).collect(); let val = vec![-1i32; window_size];
// One extra slot: a zero run can be as long as `window_size`.
let headz = vec![-1i32; window_size + 1];
let chainz: Vec<u32> = (0..window_size as u32).collect();
let zeros = vec![0u32; window_size];
let mut special_dist_table = HashMap::new();
let num_special_distances = if distance_multiplier != 0 {
// Insert from the highest index down so that, when several codes evaluate to the same
// distance, the lowest index is the one left in the table. Non-positive distances are
// skipped.
for i in (0..NUM_SPECIAL_DISTANCES).rev() {
let dist = special_distance(i, distance_multiplier);
if dist > 0 {
special_dist_table.insert(dist, i);
}
}
NUM_SPECIAL_DISTANCES
} else {
0
};
Self {
data,
size,
window_size,
window_mask: window_size - 1,
min_length,
max_length,
hash_num_values,
hash_mask,
hash_shift,
head,
chain,
val,
headz,
chainz,
zeros,
numzeros: 0,
special_dist_table,
num_special_distances,
max_chain_length: 256,
}
}
/// Hashes the three values starting at `pos` (low 16 bits of each, shifted and XORed, then
/// masked). Returns 0 when fewer than three values remain.
fn get_hash(&self, pos: usize) -> u32 {
if pos + 2 >= self.size {
return 0;
}
let mut result = 0u32;
result ^= self.data[pos] & 0xFFFF;
result ^= (self.data[pos + 1] & 0xFFFF) << self.hash_shift;
result ^= (self.data[pos + 2] & 0xFFFF) << (self.hash_shift * 2);
result & self.hash_mask as u32
}
/// Returns the number of consecutive zero values starting at `pos`, looking at most
/// `window_size` values ahead.
///
/// `prev_zeros` is the count for the preceding position. When it is non-zero the result is
/// derived from it without rescanning: it stays the same if the run already fills the window
/// and the value at the far end of a full window is zero, otherwise it drops by one.
fn count_zeros(&self, pos: usize, prev_zeros: u32) -> u32 {
let end = (pos + self.window_size).min(self.size);
if prev_zeros > 0 {
if prev_zeros >= self.window_mask as u32
&& self.data[end - 1] == 0
&& end == pos + self.window_size
{
return prev_zeros;
} else {
return prev_zeros - 1;
}
}
let mut num = 0u32;
while pos + (num as usize) < end && self.data[pos + (num as usize)] == 0 {
num += 1;
}
num
}
/// Inserts position `pos` into the hash chain and the zero-run chain.
fn update(&mut self, pos: usize) {
let hashval = self.get_hash(pos);
let wpos = pos & self.window_mask;
self.val[wpos] = hashval as i32;
// Link to the previous position with the same hash, if any, then become the new head.
if self.head[hashval as usize] != -1 {
self.chain[wpos] = self.head[hashval as usize] as u32;
}
self.head[hashval as usize] = wpos as i32;
// A value change ends the current run, so the incremental count cannot be reused.
if pos > 0 && self.data[pos] != self.data[pos - 1] {
self.numzeros = 0;
}
self.numzeros = self.count_zeros(pos, self.numzeros);
self.zeros[wpos] = self.numzeros;
if self.headz[self.numzeros as usize] != -1 {
self.chainz[wpos] = self.headz[self.numzeros as usize] as u32;
}
self.headz[self.numzeros as usize] = wpos as i32;
}
/// Calls `update` for `len` consecutive positions starting at `pos`.
fn update_range(&mut self, pos: usize, len: usize) {
for i in 0..len {
self.update(pos + i);
}
}
/// Returns `(dist_symbol, len)` of the best match at `pos`: the longest one, with ties broken
/// in favor of the smaller distance symbol. Returns `(0, 1)` when no match is reported.
fn find_match(&self, pos: usize, max_dist: usize) -> (usize, usize) {
let mut best_dist_symbol = 0usize;
let mut best_len = 1usize;
self.find_matches(pos, max_dist, |len, dist_symbol| {
if len > best_len || (len == best_len && dist_symbol < best_dist_symbol) {
best_len = len;
best_dist_symbol = dist_symbol;
}
});
(best_dist_symbol, best_len)
}
/// Walks the chain for `pos` and calls `found_match(len, dist_symbol)` for every candidate
/// whose length is at least `min_length` and within 2 of the longest found so far.
///
/// `pos` must already have been inserted with `update`, because the walk starts from its
/// chain entry and relies on `numzeros` describing `pos`. Candidates further than `max_dist`
/// back are ignored, and at most `max_chain_length` chain entries are visited.
fn find_matches<F>(&self, pos: usize, max_dist: usize, mut found_match: F)
where
F: FnMut(usize, usize),
{
let wpos = pos & self.window_mask;
let hashval = self.get_hash(pos);
let mut hashpos = self.chain[wpos];
let mut prev_dist = 0i32;
let end = (pos + self.max_length).min(self.size);
let mut chain_length = 0u32;
let mut best_len = 0usize;
loop {
// Distance back from `pos`, accounting for wrap-around in the circular window.
let dist = if hashpos as usize <= wpos {
wpos - hashpos as usize
} else {
wpos + self.window_mask + 1 - hashpos as usize
};
// Distances must not decrease along a chain; a smaller one means a stale entry.
if (dist as i32) < prev_dist {
break;
}
prev_dist = dist as i32;
if dist > 0 && dist <= max_dist {
let mut i = pos;
let mut j = pos - dist;
// Both positions start a zero run: skip the part known to match.
if self.numzeros > 3 {
let r =
((self.numzeros - 1) as usize).min(self.zeros[hashpos as usize] as usize);
let skip = if i + r >= end { end - i - 1 } else { r };
i += skip;
j += skip;
}
while i < end && self.data[i] == self.data[j] {
i += 1;
j += 1;
}
let len = i - pos;
// Slightly shorter matches are still reported (`len + 2 >= best_len`) so the caller
// can weigh them against their distance symbol.
if len >= self.min_length && len + 2 >= best_len {
let dist_symbol =
if let Some(&sym) = self.special_dist_table.get(&(dist as i32)) {
sym
} else {
self.num_special_distances + dist - 1
};
found_match(len, dist_symbol);
if len > best_len {
best_len = len;
}
}
}
chain_length += 1;
if chain_length >= self.max_chain_length {
break;
}
// NOTE(review): the switch to the zero-run chain is keyed on `best_len` (longest match so
// far), not on the current candidate's length — confirm this is intended.
if self.numzeros >= 3 && best_len > self.numzeros as usize {
if hashpos == self.chainz[hashpos as usize] {
break;
}
hashpos = self.chainz[hashpos as usize];
if self.zeros[hashpos as usize] != self.numzeros {
break;
}
} else {
if hashpos == self.chain[hashpos as usize] {
break;
}
hashpos = self.chain[hashpos as usize];
// The slot was overwritten by a position with a different hash: stop.
if self.val[hashpos as usize] != hashval as i32 {
break;
}
}
}
}
}
/// Greedy LZ77 parse with one-step lazy matching, driven by a `HashChain`.
///
/// Each accepted match replaces its first token by an LZ77 length token (value
/// `len - min_length`, same context) followed by a distance token in `lz77.distance_context`.
/// A match is accepted only if its static cost estimate (`len_cost` + `dist_cost` + the
/// context's add-symbol penalty) does not exceed the estimated cost of the literals it covers.
///
/// Returns `Some((tokens, params))` with `params.enabled == true` when the accumulated estimated
/// saving exceeds `0.2 * tokens.len() + 16` bits, otherwise `None`. Also returns `None` for an
/// empty input.
pub fn apply_lz77_backref(
tokens: &[Token],
num_contexts: usize,
force_huffman: bool,
distance_multiplier: i32,
) -> Option<(Vec<Token>, Lz77Params)> {
if tokens.is_empty() {
return None;
}
let mut lz77 = Lz77Params::new(num_contexts, force_huffman);
let sce = SymbolCostEstimator::new(num_contexts, force_huffman, tokens, &lz77);
// Prefix sums of per-token literal costs: the cost of tokens[a..b] is sym_cost[b] - sym_cost[a].
let mut sym_cost = vec![0.0f32; tokens.len() + 1];
for (i, token) in tokens.iter().enumerate() {
let e = UintCoder::encode(token.value);
let cost = sce.symbol_cost(token.context() as usize, e.token as usize) + e.nbits as f32;
sym_cost[i + 1] = sym_cost[i] + cost;
}
let mut out = Vec::with_capacity(tokens.len());
let mut bit_decrease: f32 = 0.0;
let total_symbols = tokens.len();
let max_distance = tokens.len();
let min_length = lz77.min_length as usize;
let max_length = tokens.len();
// Smallest power of two covering the stream, capped at WINDOW_SIZE.
let mut window_size = 1usize;
while window_size < max_distance && window_size < WINDOW_SIZE {
window_size <<= 1;
}
let mut chain = HashChain::new(
tokens,
window_size,
min_length,
max_length,
distance_multiplier,
);
// Matches at least this long are taken without trying the lazy alternative.
const MAX_LAZY_MATCH_LEN: usize = 256;
// True when position i was already inserted into the chain by the previous lazy probe.
let mut already_updated = false;
let mut i = 0usize;
while i < tokens.len() {
// Tentatively emit the token as a literal; it is rewritten in place if a match is taken.
out.push(tokens[i]);
if !already_updated {
chain.update(i);
}
already_updated = false;
let (mut dist_symbol, mut len) = chain.find_match(i, max_distance);
if len >= min_length {
if len < MAX_LAZY_MATCH_LEN && i + 1 < tokens.len() {
// Lazy matching: check whether starting one token later gives a longer match.
chain.update(i + 1);
already_updated = true;
let (dist_symbol2, len2) = chain.find_match(i + 1, max_distance);
if len2 > len {
// Keep tokens[i] as a literal and start the match at i + 1 instead.
i += 1;
already_updated = false;
len = len2;
dist_symbol = dist_symbol2;
out.push(tokens[i]);
}
}
let literal_cost = sym_cost[i + len] - sym_cost[i];
let lz77_len = len - min_length;
let lz77_cost = len_cost(lz77_len as u32)
+ dist_cost(dist_symbol as u32)
+ sce.add_symbol_cost(out.last().unwrap().context() as usize);
if lz77_cost <= literal_cost {
// Turn the literal just pushed into the length token and append the distance token.
let last_token = out.last_mut().unwrap();
last_token.value = lz77_len as u32;
last_token.set_lz77_length(true);
out.push(Token::new(lz77.distance_context, dist_symbol as u32));
bit_decrease += literal_cost - lz77_cost;
} else {
// Match rejected: the first token is already in `out`, copy the remaining ones.
for j in 1..len {
out.push(tokens[i + j]);
}
}
// Insert the positions covered by the match; i (and i + 1 if probed) are already in.
if already_updated {
chain.update_range(i + 2, len - 2);
already_updated = false;
} else {
chain.update_range(i + 1, len - 1);
}
i += len - 1;
}
i += 1;
}
let threshold = total_symbols as f32 * 0.2 + 16.0;
#[cfg(feature = "debug-tokens")]
eprintln!(
"[LZ77-backref] bit_decrease={:.1}, threshold={:.1}, tokens: {} -> {}, matches={}",
bit_decrease,
threshold,
total_symbols,
out.len(),
out.iter().filter(|t| t.is_lz77_length()).count()
);
if bit_decrease > threshold {
lz77.enabled = true;
Some((out, lz77))
} else {
None
}
}
/// Run-length-only LZ77: replaces runs of tokens whose value equals the preceding token's value
/// by a length token plus a "distance 1" token.
///
/// A run starting at position `i` consists of the tokens from `i` onwards whose value equals
/// `tokens[i - 1].value` (contexts are not compared). The run is replaced when it is at least
/// `min_length` long and its estimated literal cost exceeds `ceil(log2(len - min_length + 1)) + 1`
/// bits. The replacement is `Token::lz77_length(ctx_of_first, len - min_length)` followed by a
/// distance token in `lz77.distance_context`.
///
/// The distance symbol for "one token back" is `1` when special distances are enabled
/// (`distance_multiplier != 0`, where entry 1 of `SPECIAL_DISTANCES` is `[1, 0]`) and `0`
/// otherwise. The `!= 0` test matches the one used by `HashChain::new` and
/// `apply_lz77_optimal`.
///
/// Returns `Some((tokens, params))` with `params.enabled == true` when the accumulated estimated
/// saving exceeds `0.2 * tokens.len() + 16` bits, otherwise `None`. Also returns `None` for an
/// empty input.
pub fn apply_lz77_rle(
    tokens: &[Token],
    num_contexts: usize,
    force_huffman: bool,
    distance_multiplier: i32,
) -> Option<(Vec<Token>, Lz77Params)> {
    if tokens.is_empty() {
        return None;
    }
    let mut lz77 = Lz77Params::new(num_contexts, force_huffman);
    let min_length = lz77.min_length as usize;
    let rle_distance_symbol: u32 = if distance_multiplier != 0 { 1 } else { 0 };
    let sce = SymbolCostEstimator::new(num_contexts, force_huffman, tokens, &lz77);

    // Prefix sums of per-token literal costs: tokens[a..b] costs sym_cost[b] - sym_cost[a].
    let mut sym_cost = vec![0.0f32; tokens.len() + 1];
    for (i, token) in tokens.iter().enumerate() {
        let e = UintCoder::encode(token.value);
        let cost = sce.symbol_cost(token.context() as usize, e.token as usize) + e.nbits as f32;
        sym_cost[i + 1] = sym_cost[i] + cost;
    }

    let mut out = Vec::with_capacity(tokens.len());
    let mut bit_decrease: f32 = 0.0;
    let total_symbols = tokens.len();
    let mut i = 0;
    while i < tokens.len() {
        // Length of the run of tokens repeating the previous token's value (0 at the start).
        let num_to_copy = if i == 0 {
            0
        } else {
            let prev_value = tokens[i - 1].value;
            tokens[i..]
                .iter()
                .take_while(|t| t.value == prev_value)
                .count()
        };
        if num_to_copy == 0 {
            out.push(tokens[i]);
            i += 1;
            continue;
        }

        let literal_cost = sym_cost[i + num_to_copy] - sym_cost[i];
        let lz77_cost = if num_to_copy >= min_length {
            ceil_log2_nonzero((num_to_copy - min_length + 1) as u32) as f32 + 1.0
        } else {
            0.0
        };
        if num_to_copy < min_length || literal_cost <= lz77_cost {
            // Too short or not worth it: keep the run as literals.
            out.extend_from_slice(&tokens[i..i + num_to_copy]);
            i += num_to_copy;
            continue;
        }

        let lz77_len = (num_to_copy - min_length) as u32;
        out.push(Token::lz77_length(tokens[i].context(), lz77_len));
        out.push(Token::new(lz77.distance_context, rle_distance_symbol));
        bit_decrease += literal_cost - lz77_cost;
        i += num_to_copy;
    }

    let threshold = total_symbols as f32 * 0.2 + 16.0;
    #[cfg(feature = "debug-tokens")]
    eprintln!(
        "[LZ77-RLE] bit_decrease={:.1}, threshold={:.1}, tokens: {} -> {}, runs_found={}",
        bit_decrease,
        threshold,
        total_symbols,
        out.len(),
        out.iter().filter(|t| t.is_lz77_length()).count()
    );
    if bit_decrease > threshold {
        lz77.enabled = true;
        Some((out, lz77))
    } else {
        None
    }
}
/// Returns `ceil(log2(x))` for `x > 0`.
///
/// Debug builds assert that `x` is non-zero.
fn ceil_log2_nonzero(x: u32) -> u32 {
    debug_assert!(x > 0);
    let floor_log2 = 31 - x.leading_zeros();
    // Exact powers of two need no rounding up; everything else rounds to the next integer.
    let round_up = u32::from(!x.is_power_of_two());
    floor_log2 + round_up
}
/// Selects which LZ77 parser `apply_lz77` runs.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum Lz77Method {
/// Run-length matching only (`apply_lz77_rle`). This is the default.
#[default]
Rle,
/// Greedy hash-chain matching with lazy evaluation (`apply_lz77_backref`).
Greedy,
/// Cost-driven optimal parse seeded by the greedy result (`apply_lz77_optimal`).
Optimal,
}
pub fn apply_lz77(
tokens: &[Token],
num_contexts: usize,
force_huffman: bool,
method: Lz77Method,
distance_multiplier: i32,
) -> Option<(Vec<Token>, Lz77Params)> {
match method {
Lz77Method::Rle => apply_lz77_rle(tokens, num_contexts, force_huffman, distance_multiplier),
Lz77Method::Greedy => {
apply_lz77_backref(tokens, num_contexts, force_huffman, distance_multiplier)
}
Lz77Method::Optimal => {
apply_lz77_optimal(tokens, num_contexts, force_huffman, distance_multiplier)
}
}
}
/// Optimal LZ77 parse: a shortest-path search over token positions, using costs estimated from
/// the greedy parse.
///
/// The greedy parser runs first. If it returns `None` (or `tokens` is empty) this function
/// returns `None` too. Otherwise a `SymbolCostEstimator` is built from the greedy output (with
/// one extra context for distance tokens), and a dynamic program computes the cheapest way to
/// reach every prefix length using literals and the matches reported by `HashChain`. The chosen
/// path is then read back from the end and emitted in forward order.
///
/// On success the returned parameters always have `enabled == true`.
pub fn apply_lz77_optimal(
tokens: &[Token],
num_contexts: usize,
force_huffman: bool,
distance_multiplier: i32,
) -> Option<(Vec<Token>, Lz77Params)> {
if tokens.is_empty() {
return None;
}
let greedy_result =
apply_lz77_backref(tokens, num_contexts, force_huffman, distance_multiplier);
let greedy_tokens = match &greedy_result {
Some((t, _)) => t,
None => return None,
};
let mut lz77 = Lz77Params::new(num_contexts, force_huffman);
lz77.enabled = true;
// `num_contexts + 1` so that the distance context (id `num_contexts`) gets its own histogram.
let sce = SymbolCostEstimator::new(num_contexts + 1, force_huffman, greedy_tokens, &lz77);
// Prefix sums of per-token literal costs: tokens[a..b] costs sym_cost[b] - sym_cost[a].
let mut sym_cost = vec![0.0f32; tokens.len() + 1];
for (i, token) in tokens.iter().enumerate() {
let e = UintCoder::encode(token.value);
let cost = sce.symbol_cost(token.context() as usize, e.token as usize) + e.nbits as f32;
sym_cost[i + 1] = sym_cost[i] + cost;
}
let max_distance = tokens.len();
let min_length = lz77.min_length as usize;
let max_length = tokens.len();
// Smallest power of two covering the stream, capped at WINDOW_SIZE.
let mut window_size = 1usize;
while window_size < max_distance && window_size < WINDOW_SIZE {
window_size <<= 1;
}
let mut chain = HashChain::new(
tokens,
window_size,
min_length,
max_length,
distance_multiplier,
);
// Best known way to reach a prefix: the last step's length, its distance symbol plus one
// (0 means the step is a literal), the context of the step's first token, and the total cost.
struct PrefixInfo {
len: u32,
dist_symbol: u32, ctx: u32,
total_cost: f32,
}
let n = tokens.len();
let mut prefix_costs: Vec<PrefixInfo> = (0..=n)
.map(|_| PrefixInfo {
len: 0,
dist_symbol: 0,
ctx: 0,
total_cost: f32::MAX,
})
.collect();
prefix_costs[0].total_cost = 0.0;
// Number of consecutive positions whose longest match was a "distance 1" (RLE) match.
let mut rle_length = 0usize;
// Number of upcoming positions for which match search is skipped.
let mut skip_lz77 = 0usize;
// Scratch: dist_symbols[len] is the cheapest distance symbol for a match of that length.
let mut dist_symbols: Vec<u32> = Vec::new();
for i in 0..n {
chain.update(i);
// Option 1: code tokens[i] as a literal.
let lit_cost = prefix_costs[i].total_cost + sym_cost[i + 1] - sym_cost[i];
if prefix_costs[i + 1].total_cost > lit_cost {
prefix_costs[i + 1].dist_symbol = 0;
prefix_costs[i + 1].len = 1;
prefix_costs[i + 1].ctx = tokens[i].context();
prefix_costs[i + 1].total_cost = lit_cost;
}
if skip_lz77 > 0 {
skip_lz77 -= 1;
continue;
}
dist_symbols.clear();
chain.find_matches(i, max_distance, |len, dist_symbol| {
if dist_symbols.len() <= len {
dist_symbols.resize(len + 1, dist_symbol as u32);
}
if (dist_symbol as u32) < dist_symbols[len] {
dist_symbols[len] = dist_symbol as u32;
}
});
if dist_symbols.len() <= min_length {
continue;
}
{
// A longer match can always be truncated, so propagate the smallest symbol downwards.
let mut best_cost = dist_symbols[dist_symbols.len() - 1];
for j in (min_length..dist_symbols.len()).rev() {
if dist_symbols[j] < best_cost {
best_cost = dist_symbols[j];
}
dist_symbols[j] = best_cost;
}
}
// Option 2: start a match of every feasible length j at position i.
for (j, &dsym) in dist_symbols.iter().enumerate().skip(min_length) {
let target = i + j;
if target > n {
break;
}
let lz77_cost =
sce.len_cost(tokens[i].context() as usize, (j - min_length) as u32, &lz77)
+ sce.dist_cost_sce(dsym, &lz77);
let cost = prefix_costs[i].total_cost + lz77_cost;
if prefix_costs[target].total_cost > cost {
prefix_costs[target].len = j as u32;
// Stored off by one so that 0 keeps meaning "literal".
prefix_costs[target].dist_symbol = dsym + 1; prefix_costs[target].ctx = tokens[i].context();
prefix_costs[target].total_cost = cost;
}
}
// Inside a long run, stop searching at every position: after 8 consecutive RLE positions
// skip ahead, leaving the last positions of the run to be searched again.
let last_dist = dist_symbols[dist_symbols.len() - 1];
if (last_dist == 0 && distance_multiplier == 0)
|| (last_dist == 1 && distance_multiplier != 0)
{
rle_length += 1;
} else {
rle_length = 0;
}
if rle_length >= 8 && dist_symbols.len() > 9 {
skip_lz77 = dist_symbols.len() - 10;
rle_length = 0;
}
}
// Backtrack from the end. Tokens are pushed in reverse (distance before length) and the
// vector is reversed afterwards.
let mut out = Vec::with_capacity(n);
let mut pos = n;
while pos > 0 {
let info = &prefix_costs[pos];
let is_lz77 = info.dist_symbol != 0;
if is_lz77 {
let dist_symbol = info.dist_symbol - 1;
out.push(Token::new(lz77.distance_context, dist_symbol));
}
let val = if is_lz77 {
info.len - min_length as u32
} else {
tokens[pos - 1].value
};
let mut tok = Token::new(info.ctx, val);
tok.set_lz77_length(is_lz77);
out.push(tok);
pos -= info.len as usize;
}
out.reverse();
Some((out, lz77))
}
/// Runs both the RLE and the greedy back-reference parser and returns the result with fewer
/// tokens, preferring the back-reference result on a tie.
///
/// When only one parser produces a result, that result is returned. `None` is returned when
/// neither does.
#[allow(dead_code)]
pub fn apply_lz77_best(
    tokens: &[Token],
    num_contexts: usize,
    force_huffman: bool,
    distance_multiplier: i32,
) -> Option<(Vec<Token>, Lz77Params)> {
    let rle = apply_lz77_rle(tokens, num_contexts, force_huffman, distance_multiplier);
    let backref = apply_lz77_backref(tokens, num_contexts, force_huffman, distance_multiplier);
    match (rle, backref) {
        (Some(rle), Some(backref)) => {
            let winner = if backref.0.len() <= rle.0.len() {
                backref
            } else {
                rle
            };
            Some(winner)
        }
        // At most one side is `Some` here.
        (rle, backref) => rle.or(backref),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `count` tokens in context 0 that all carry `value`.
    fn constant_run(value: u32, count: usize) -> Vec<Token> {
        (0..count).map(|_| Token::new(0, value)).collect()
    }

    /// Builds context-0 tokens by repeating `pattern` `repeats` times.
    fn repeated_pattern(pattern: &[u32], repeats: usize) -> Vec<Token> {
        (0..repeats)
            .flat_map(|_| pattern.iter().copied())
            .map(|value| Token::new(0, value))
            .collect()
    }

    #[test]
    fn test_ceil_log2_nonzero() {
        let cases = [(1, 0), (2, 1), (3, 2), (4, 2), (5, 3), (8, 3), (9, 4)];
        for (input, expected) in cases {
            assert_eq!(ceil_log2_nonzero(input), expected);
        }
    }

    #[test]
    fn test_no_rle_on_short_stream() {
        let tokens = constant_run(5, 3);
        assert!(apply_lz77_rle(&tokens, 1, false, 0).is_none());
    }

    #[test]
    fn test_rle_on_long_run() {
        let tokens = constant_run(5, 201);
        if let Some((lz77_tokens, params)) = apply_lz77_rle(&tokens, 1, false, 0) {
            assert!(params.enabled);
            assert!(lz77_tokens.len() < tokens.len());
            assert!(lz77_tokens.iter().any(|t| t.is_lz77_length()));
        }
    }

    #[test]
    fn test_rle_preserves_non_runs() {
        // Ten distinct values, a run of one hundred 42s, then ten more distinct values.
        let tokens: Vec<Token> = (0..10u32)
            .chain(std::iter::repeat(42).take(100))
            .chain(100..110)
            .map(|value| Token::new(0, value))
            .collect();
        if let Some((lz77_tokens, params)) = apply_lz77_rle(&tokens, 1, false, 0) {
            assert!(params.enabled);
            assert!(lz77_tokens.len() < tokens.len());
            assert_eq!(lz77_tokens[0].value, 0);
            assert!(!lz77_tokens[0].is_lz77_length());
        }
    }

    #[test]
    fn test_empty_stream() {
        assert!(apply_lz77_rle(&[], 1, false, 0).is_none());
    }

    #[test]
    fn test_backref_empty_stream() {
        assert!(apply_lz77_backref(&[], 1, false, 0).is_none());
    }

    #[test]
    fn test_backref_short_stream() {
        let tokens = constant_run(5, 3);
        assert!(apply_lz77_backref(&tokens, 1, false, 0).is_none());
    }

    #[test]
    fn test_backref_on_repeating_pattern() {
        let tokens = repeated_pattern(&[10, 20, 30], 100);
        if let Some((lz77_tokens, params)) = apply_lz77_backref(&tokens, 1, false, 0) {
            assert!(params.enabled);
            assert!(
                lz77_tokens.len() < tokens.len(),
                "backref should compress pattern: {} vs {}",
                lz77_tokens.len(),
                tokens.len()
            );
            assert!(lz77_tokens.iter().any(|t| t.is_lz77_length()));
        }
    }

    #[test]
    fn test_backref_finds_longer_matches_than_rle() {
        let tokens = repeated_pattern(&[1, 2, 3, 4, 5], 50);
        let rle_result = apply_lz77_rle(&tokens, 1, false, 0);
        let backref_result = apply_lz77_backref(&tokens, 1, false, 0);
        if let Some((backref_tokens, _)) = &backref_result {
            match &rle_result {
                Some((rle_tokens, _)) => assert!(backref_tokens.len() <= rle_tokens.len()),
                None => assert!(backref_tokens.len() < tokens.len()),
            }
        }
    }

    #[test]
    fn test_backref_with_distance_multiplier() {
        let image_width = 64;
        // Twenty identical rows, each cycling through the values 0..16.
        let tokens: Vec<Token> = (0..20)
            .flat_map(|_row| 0..image_width)
            .map(|col| Token::new(0, (col % 16) as u32))
            .collect();
        let _result_no_mult = apply_lz77_backref(&tokens, 1, false, 0);
        let result_with_mult = apply_lz77_backref(&tokens, 1, false, image_width);
        if let Some((tokens_mult, params)) = result_with_mult {
            assert!(params.enabled);
            assert!(tokens_mult.len() < tokens.len());
        }
    }

    #[test]
    fn test_special_distance() {
        let expected = [64, 1, 65, 63];
        for (index, &distance) in expected.iter().enumerate() {
            assert_eq!(special_distance(index, 64), distance);
        }
    }

    #[test]
    fn test_len_cost() {
        for len in 0..1000 {
            let cost = len_cost(len);
            assert!(cost >= 0.0, "len_cost({}) should be non-negative", len);
            assert!(cost < 100.0, "len_cost({}) should be reasonable", len);
        }
    }

    #[test]
    fn test_dist_cost() {
        for dist in 0..10000 {
            let cost = dist_cost(dist);
            assert!(cost >= 0.0, "dist_cost({}) should be non-negative", dist);
            assert!(cost < 100.0, "dist_cost({}) should be reasonable", dist);
        }
    }

    #[test]
    fn test_apply_lz77_method_enum() {
        let tokens = constant_run(5, 201);
        let rle_result = apply_lz77(&tokens, 1, false, Lz77Method::Rle, 0);
        if let Some((_, params)) = &rle_result {
            assert!(params.enabled);
        }
        let greedy_result = apply_lz77(&tokens, 1, false, Lz77Method::Greedy, 0);
        if let Some((_, params)) = &greedy_result {
            assert!(params.enabled);
        }
    }

    #[test]
    fn test_apply_lz77_best() {
        let pattern: Vec<u32> = (1..=10).collect();
        let tokens = repeated_pattern(&pattern, 50);
        if let Some((best_tokens, params)) = apply_lz77_best(&tokens, 1, false, 0) {
            assert!(params.enabled);
            assert!(best_tokens.len() < tokens.len());
        }
    }

    #[test]
    fn test_hash_chain_basic() {
        // The triple 10, 20, 30 occurs at positions 0 and 4.
        let tokens = repeated_pattern(&[10, 20, 30, 40, 10, 20, 30], 1);
        let mut chain = HashChain::new(&tokens, 16, 3, 100, 0);
        chain.update_range(0, tokens.len());
        let (dist_symbol, len) = chain.find_match(4, 10);
        assert!(len >= 3, "should find match of length >= 3, got {}", len);
        assert_eq!(dist_symbol, 3, "distance symbol for dist=4 should be 3");
    }
}