1use hashbrown::HashMap;
16
17use super::token::{Lz77UintCoder, Token, UintCoder};
18use crate::bit_writer::BitWriter;
19use crate::error::Result;
20
/// Upper bound on the hash-chain window size.
///
/// The matchers grow their window by doubling from 1 until it reaches the
/// input length or this cap, so the window never exceeds `1 << 20` entries.
const WINDOW_SIZE: usize = 1 << 20;
23
/// Number of rows in `SPECIAL_DISTANCES`; also the offset added to ordinary
/// distance symbols when special distances are in use.
const NUM_SPECIAL_DISTANCES: usize = 120;
26
/// Table of special-distance coefficient pairs.
///
/// Row `i` is `[a, b]`; `special_distance(i, multiplier)` evaluates it as
/// `a + multiplier * b`.
/// NOTE(review): presumably `[dx, dy]` pixel offsets with `multiplier` being
/// the image width (the tests pass an image width) — confirm against callers.
#[rustfmt::skip]
const SPECIAL_DISTANCES: [[i8; 2]; NUM_SPECIAL_DISTANCES] = [
    [0, 1], [1, 0], [1, 1], [-1, 1], [0, 2], [2, 0], [1, 2], [-1, 2],
    [2, 1], [-2, 1], [2, 2], [-2, 2], [0, 3], [3, 0], [1, 3], [-1, 3],
    [3, 1], [-3, 1], [2, 3], [-2, 3], [3, 2], [-3, 2], [0, 4], [4, 0],
    [1, 4], [-1, 4], [4, 1], [-4, 1], [3, 3], [-3, 3], [2, 4], [-2, 4],
    [4, 2], [-4, 2], [0, 5], [3, 4], [-3, 4], [4, 3], [-4, 3], [5, 0],
    [1, 5], [-1, 5], [5, 1], [-5, 1], [2, 5], [-2, 5], [5, 2], [-5, 2],
    [4, 4], [-4, 4], [3, 5], [-3, 5], [5, 3], [-5, 3], [0, 6], [6, 0],
    [1, 6], [-1, 6], [6, 1], [-6, 1], [2, 6], [-2, 6], [6, 2], [-6, 2],
    [4, 5], [-4, 5], [5, 4], [-5, 4], [3, 6], [-3, 6], [6, 3], [-6, 3],
    [0, 7], [7, 0], [1, 7], [-1, 7], [5, 5], [-5, 5], [7, 1], [-7, 1],
    [4, 6], [-4, 6], [6, 4], [-6, 4], [2, 7], [-2, 7], [7, 2], [-7, 2],
    [3, 7], [-3, 7], [7, 3], [-7, 3], [5, 6], [-5, 6], [6, 5], [-6, 5],
    [8, 0], [4, 7], [-4, 7], [7, 4], [-7, 4], [8, 1], [8, 2], [6, 6],
    [-6, 6], [8, 3], [5, 7], [-5, 7], [7, 5], [-7, 5], [8, 4], [6, 7],
    [-6, 7], [7, 6], [-7, 6], [8, 5], [7, 7], [-7, 7], [8, 6], [8, 7],
];
48
49#[inline]
51fn special_distance(index: usize, multiplier: i32) -> i32 {
52 SPECIAL_DISTANCES[index][0] as i32 + multiplier * SPECIAL_DISTANCES[index][1] as i32
53}
54
/// Static per-token base cost (in bits) for LZ77 length tokens.
///
/// Indexed by the token computed in `len_cost`; tokens past the end are
/// clamped to the last entry. The extra raw bits are added separately.
/// NOTE(review): the values look like empirically fitted costs — their origin
/// is not visible in this file.
#[rustfmt::skip]
#[allow(clippy::excessive_precision)]
const LEN_COST_TABLE: [f32; 17] = [
    2.797667318563126, 3.213177690381199, 2.5706009246743737,
    2.408392498667534, 2.829649191872326, 3.3923087753324577,
    4.029267451554331, 4.415576699706408, 4.509357574741465,
    9.21481543803004, 10.020590190114898, 11.858671627804766,
    12.45853300490526, 11.713105831990857, 12.561996324849314,
    13.775477692278367, 13.174027068768641,
];
67
/// Static per-token base cost (in bits) for LZ77 distance tokens.
///
/// Indexed by the token returned from `hybrid_uint_encode_7_0_0`; tokens past
/// the end are clamped to the last entry (see `dist_cost`). The extra raw
/// bits are added separately.
/// NOTE(review): 139 = 128 + 11, which would match a hybrid-uint alphabet with
/// split exponent 7 (128 direct tokens followed by 11 exponent tokens) —
/// confirm against the token encoder.
#[rustfmt::skip]
#[allow(clippy::excessive_precision)]
const DIST_COST_TABLE: [f32; 139] = [
    // Entries 0..=127.
    6.368282626312716, 5.680793277090298, 8.347404197105247,
    7.641619201599141, 6.914328374119438, 7.959808291537444,
    8.70023120759855, 8.71378518934703, 9.379132523982769,
    9.110472749092708, 9.159029569270908, 9.430936766731973,
    7.278284055315169, 7.8278514904267755, 10.026641158289236,
    9.976049229827066, 9.64351607048908, 9.563403863480442,
    10.171474111762747, 10.45950155077234, 9.994813912104219,
    10.322524683741156, 8.465808729388186, 8.756254166066853,
    10.160930174662234, 10.247329273413435, 10.04090403724809,
    10.129398517544082, 9.342311691539546, 9.07608009102374,
    10.104799540677513, 10.378079384990906, 10.165828974075072,
    10.337595322341553, 7.940557464567944, 10.575665823319431,
    11.023344321751955, 10.736144698831827, 11.118277044595054,
    7.468468230648442, 10.738305230932939, 10.906980780216568,
    10.163468216353817, 10.17805759656433, 11.167283670483565,
    11.147050200274544, 10.517921919244333, 10.651764778156886,
    10.17074446448919, 11.217636876224745, 11.261630721139484,
    11.403140815247259, 10.892472096873417, 11.1859607804481,
    8.017346947551262, 7.895143720278828, 11.036577113822025,
    11.170562110315794, 10.326988722591086, 10.40872184751056,
    11.213498225466386, 11.30580635516863, 10.672272515665442,
    10.768069466228063, 11.145257364153565, 11.64668307145549,
    10.593156194627339, 11.207499484844943, 10.767517766396908,
    10.826629811407042, 10.737764794499988, 10.6200448518045,
    10.191315385198092, 8.468384171390085, 11.731295299170432,
    11.824619886654398, 10.41518844301179, 10.16310536548649,
    10.539423685097576, 10.495136599328031, 10.469112847728267,
    11.72057686174922, 10.910326337834674, 11.378921834673758,
    11.847759036098536, 11.92071647623854, 10.810628276345282,
    11.008601085273893, 11.910326337834674, 11.949212023423133,
    11.298614839104337, 11.611603659010392, 10.472930394619985,
    11.835564720850282, 11.523267392285337, 12.01055816679611,
    8.413029688994023, 11.895784139536406, 11.984679534970505,
    11.220654278717394, 11.716311684833672, 10.61036646226114,
    10.89849965960364, 10.203762898863669, 10.997560826267238,
    11.484217379438984, 11.792836176993665, 12.24310468755171,
    11.464858097919262, 12.212747017409377, 11.425595666074955,
    11.572048533398757, 12.742093965163013, 11.381874288645637,
    12.191870445817015, 11.683156920035426, 11.152442115262197,
    11.90303691580457, 11.653292787169159, 11.938615382266098,
    16.970641701570223, 16.853602280380002, 17.26240782594733,
    16.644655390108507, 17.14310889757499, 16.910935455445955,
    17.505678976959697, 17.213498225466388,
    // Entries 128..=138.
    2.4162310293553024, 3.494587244462329, 3.5258600986408344,
    3.4959806589517095, 3.098390886949687, 3.343454654302911,
    3.588847442290287, 4.14614790111827, 5.152948641990529,
    7.433696808092598, 9.716311684833672,
];
125
126fn len_cost(len: u32) -> f32 {
128 let (tok, nbits) = if len == 0 {
130 (0u32, 0u32)
131 } else {
132 let n = 31 - len.leading_zeros();
133 (1 + n, n)
134 };
135 let table_size = LEN_COST_TABLE.len();
136 let tok_idx = (tok as usize).min(table_size - 1);
137 LEN_COST_TABLE[tok_idx] + nbits as f32
138}
139
140fn dist_cost(dist: u32) -> f32 {
142 let (tok, nbits) = hybrid_uint_encode_7_0_0(dist);
144 let table_size = DIST_COST_TABLE.len();
145 let tok_idx = (tok as usize).min(table_size - 1);
146 DIST_COST_TABLE[tok_idx] + nbits as f32
147}
148
/// Tokenizes `value` with a hybrid-uint configuration of
/// `split_exponent = 7`, `msb_in_token = 0`, `lsb_in_token = 0`.
///
/// Returns `(token, nbits)`:
/// * values below the split token (`1 << 7 = 128`) are their own token and
///   need no extra bits;
/// * larger values map to `128 + (n - 7)` where `n = floor(log2(value))`, and
///   need `n` raw extra bits.
///
/// The split point is `1 << split_exponent`, not the exponent itself; using 7
/// as the threshold would merge values 6 and 7 into one token and never
/// produce tokens above 35.
fn hybrid_uint_encode_7_0_0(value: u32) -> (u32, u32) {
    const SPLIT_EXPONENT: u32 = 7;
    const SPLIT_TOKEN: u32 = 1 << SPLIT_EXPONENT;

    if value < SPLIT_TOKEN {
        (value, 0)
    } else {
        // value >= 128 guarantees n >= 7, so the subtraction cannot underflow.
        let n = 31 - value.leading_zeros();
        (SPLIT_TOKEN + (n - SPLIT_EXPONENT), n)
    }
}
162
/// Parameters describing how LZ77 tokens are embedded in a token stream.
#[derive(Debug, Clone)]
pub struct Lz77Params {
    /// Whether LZ77 is in use; the `apply_lz77_*` functions set this when they
    /// accept their rewritten stream.
    pub enabled: bool,
    /// Offset added to length tokens to separate them from literal tokens.
    pub min_symbol: u32,
    /// Shortest match length that is encoded; stored lengths are relative to it.
    pub min_length: u32,
    /// Context id attached to distance tokens.
    pub distance_context: u32,
}

impl Lz77Params {
    /// Creates disabled parameters with the default thresholds.
    ///
    /// `min_symbol` is 512 when `force_huffman` is set and 224 otherwise,
    /// `min_length` is 3, and distance tokens use context `num_contexts`
    /// (one past the caller's regular contexts).
    pub fn new(num_contexts: usize, force_huffman: bool) -> Self {
        let min_symbol = if force_huffman { 512 } else { 224 };
        let distance_context = num_contexts as u32;
        Self {
            enabled: false,
            min_symbol,
            min_length: 3,
            distance_context,
        }
    }
}
186
187pub fn write_lz77_header(lz77: Option<&Lz77Params>, writer: &mut BitWriter) -> Result<()> {
201 if let Some(params) = lz77 {
202 writer.write(1, 1)?; match params.min_symbol {
206 224 => writer.write(2, 0)?, 512 => writer.write(2, 1)?, 4096 => writer.write(2, 2)?, v => {
210 writer.write(2, 3)?; writer.write(15, (v - 8) as u64)?;
212 }
213 }
214
215 match params.min_length {
217 3 => writer.write(2, 0)?, 4 => writer.write(2, 1)?, v @ 5..=8 => {
220 writer.write(2, 2)?; writer.write(2, (v - 5) as u64)?;
222 }
223 v => {
224 writer.write(2, 3)?; writer.write(8, (v - 9) as u64)?;
226 }
227 }
228
229 writer.write(4, 0)?;
232 } else {
233 writer.write(1, 0)?; }
235 Ok(())
236}
237
238struct SymbolCostEstimator {
240 bits: Vec<f32>,
242 max_alphabet_size: usize,
243}
244
245impl SymbolCostEstimator {
246 fn new(num_contexts: usize, force_huffman: bool, tokens: &[Token], lz77: &Lz77Params) -> Self {
247 const ANS_LOG_TAB_SIZE: f32 = 12.0;
248
249 let mut counts: Vec<Vec<u32>> = vec![vec![]; num_contexts];
251 let mut total_counts = vec![0u32; num_contexts];
252
253 for token in tokens {
254 let (tok, _nbits) = if token.is_lz77_length() {
255 let e = Lz77UintCoder::encode(token.value);
256 (e.token + lz77.min_symbol, e.nbits)
257 } else {
258 let e = UintCoder::encode(token.value);
259 (e.token, e.nbits)
260 };
261 let ctx = token.context() as usize;
262 if ctx < num_contexts {
263 let sym = tok as usize;
264 if sym >= counts[ctx].len() {
265 counts[ctx].resize(sym + 1, 0);
266 }
267 counts[ctx][sym] += 1;
268 total_counts[ctx] += 1;
269 }
270 }
271
272 let max_alphabet_size = counts.iter().map(|c| c.len()).max().unwrap_or(0);
273 let mut bits = vec![0.0f32; num_contexts * max_alphabet_size];
274
275 for ctx in 0..num_contexts {
276 let total = total_counts[ctx];
277 if total == 0 {
278 continue;
279 }
280 let inv_total = 1.0 / (total as f32 + 1e-8);
281 for sym in 0..counts[ctx].len() {
282 let cnt = counts[ctx][sym];
283 let cost = if cnt != 0 && cnt != total {
284 let p = cnt as f32 * inv_total;
285 let c = -jxl_simd::fast_log2f(p);
286 if force_huffman { c.ceil() } else { c }
287 } else if cnt == 0 {
288 ANS_LOG_TAB_SIZE } else {
290 0.0 };
292 bits[ctx * max_alphabet_size + sym] = cost;
293 }
294 }
295
296 Self {
297 bits,
298 max_alphabet_size,
299 }
300 }
301
302 #[inline]
303 fn symbol_cost(&self, ctx: usize, sym: usize) -> f32 {
304 if sym < self.max_alphabet_size {
305 self.bits[ctx * self.max_alphabet_size + sym]
306 } else {
307 12.0 }
309 }
310
311 fn add_symbol_cost(&self, ctx: usize) -> f32 {
313 let mut total_cost = 0.0f32;
315 let mut total_count = 0u32;
316 for sym in 0..self.max_alphabet_size {
317 let cost = self.bits[ctx * self.max_alphabet_size + sym];
318 if cost < 12.0 {
319 total_cost += cost;
321 total_count += 1;
322 }
323 }
324 if total_count == 0 {
325 return 0.0;
326 }
327 (6.0 - total_cost / total_count as f32).max(0.0)
329 }
330
331 fn len_cost(&self, ctx: usize, len: u32, lz77: &Lz77Params) -> f32 {
333 let (tok, nbits) = if len == 0 {
335 (0u32, 0u32)
336 } else {
337 let n = 31 - len.leading_zeros();
338 (1 + n, n)
339 };
340 let sym = tok + lz77.min_symbol;
341 nbits as f32 + self.symbol_cost(ctx, sym as usize)
342 }
343
344 fn dist_cost_sce(&self, dist_symbol: u32, lz77: &Lz77Params) -> f32 {
346 let (tok, nbits) = UintCoder::encode(dist_symbol).into();
347 nbits as f32 + self.symbol_cost(lz77.distance_context as usize, tok as usize)
348 }
349}
350
351struct HashChain {
356 data: Vec<u32>,
358 size: usize,
360 window_size: usize,
362 window_mask: usize,
364 min_length: usize,
366 max_length: usize,
368
369 #[allow(dead_code)] hash_num_values: usize,
372 hash_mask: usize,
373 hash_shift: u32,
374
375 head: Vec<i32>,
377 chain: Vec<u32>,
379 val: Vec<i32>,
381
382 headz: Vec<i32>,
385 chainz: Vec<u32>,
387 zeros: Vec<u32>,
389 numzeros: u32,
391
392 special_dist_table: HashMap<i32, usize>,
394 num_special_distances: usize,
396
397 max_chain_length: u32,
399}
400
401impl HashChain {
402 fn new(
403 tokens: &[Token],
404 window_size: usize,
405 min_length: usize,
406 max_length: usize,
407 distance_multiplier: i32,
408 ) -> Self {
409 let size = tokens.len();
410
411 let data: Vec<u32> = tokens.iter().map(|t| t.value).collect();
413
414 let hash_num_values = 32768usize;
416 let hash_mask = hash_num_values - 1;
417 let hash_shift = 5u32;
418
419 let head = vec![-1i32; hash_num_values];
420 let chain: Vec<u32> = (0..window_size as u32).collect(); let val = vec![-1i32; window_size];
422
423 let headz = vec![-1i32; window_size + 1];
425 let chainz: Vec<u32> = (0..window_size as u32).collect();
426 let zeros = vec![0u32; window_size];
427
428 let mut special_dist_table = HashMap::new();
430 let num_special_distances = if distance_multiplier != 0 {
431 for i in (0..NUM_SPECIAL_DISTANCES).rev() {
433 let dist = special_distance(i, distance_multiplier);
434 if dist > 0 {
435 special_dist_table.insert(dist, i);
436 }
437 }
438 NUM_SPECIAL_DISTANCES
439 } else {
440 0
441 };
442
443 Self {
444 data,
445 size,
446 window_size,
447 window_mask: window_size - 1,
448 min_length,
449 max_length,
450 hash_num_values,
451 hash_mask,
452 hash_shift,
453 head,
454 chain,
455 val,
456 headz,
457 chainz,
458 zeros,
459 numzeros: 0,
460 special_dist_table,
461 num_special_distances,
462 max_chain_length: 256,
463 }
464 }
465
466 fn get_hash(&self, pos: usize) -> u32 {
468 if pos + 2 >= self.size {
469 return 0;
470 }
471 let mut result = 0u32;
472 result ^= self.data[pos] & 0xFFFF;
473 result ^= (self.data[pos + 1] & 0xFFFF) << self.hash_shift;
474 result ^= (self.data[pos + 2] & 0xFFFF) << (self.hash_shift * 2);
475 result & self.hash_mask as u32
476 }
477
478 fn count_zeros(&self, pos: usize, prev_zeros: u32) -> u32 {
480 let end = (pos + self.window_size).min(self.size);
481 if prev_zeros > 0 {
482 if prev_zeros >= self.window_mask as u32
483 && self.data[end - 1] == 0
484 && end == pos + self.window_size
485 {
486 return prev_zeros;
487 } else {
488 return prev_zeros - 1;
489 }
490 }
491 let mut num = 0u32;
492 while pos + (num as usize) < end && self.data[pos + (num as usize)] == 0 {
493 num += 1;
494 }
495 num
496 }
497
498 fn update(&mut self, pos: usize) {
500 let hashval = self.get_hash(pos);
501 let wpos = pos & self.window_mask;
502
503 self.val[wpos] = hashval as i32;
504 if self.head[hashval as usize] != -1 {
505 self.chain[wpos] = self.head[hashval as usize] as u32;
506 }
507 self.head[hashval as usize] = wpos as i32;
508
509 if pos > 0 && self.data[pos] != self.data[pos - 1] {
511 self.numzeros = 0;
512 }
513 self.numzeros = self.count_zeros(pos, self.numzeros);
514
515 self.zeros[wpos] = self.numzeros;
516 if self.headz[self.numzeros as usize] != -1 {
517 self.chainz[wpos] = self.headz[self.numzeros as usize] as u32;
518 }
519 self.headz[self.numzeros as usize] = wpos as i32;
520 }
521
522 fn update_range(&mut self, pos: usize, len: usize) {
524 for i in 0..len {
525 self.update(pos + i);
526 }
527 }
528
529 fn find_match(&self, pos: usize, max_dist: usize) -> (usize, usize) {
532 let mut best_dist_symbol = 0usize;
533 let mut best_len = 1usize;
534
535 self.find_matches(pos, max_dist, |len, dist_symbol| {
536 if len > best_len || (len == best_len && dist_symbol < best_dist_symbol) {
537 best_len = len;
538 best_dist_symbol = dist_symbol;
539 }
540 });
541
542 (best_dist_symbol, best_len)
543 }
544
545 fn find_matches<F>(&self, pos: usize, max_dist: usize, mut found_match: F)
547 where
548 F: FnMut(usize, usize),
549 {
550 let wpos = pos & self.window_mask;
551 let hashval = self.get_hash(pos);
552 let mut hashpos = self.chain[wpos];
553
554 let mut prev_dist = 0i32;
555 let end = (pos + self.max_length).min(self.size);
556 let mut chain_length = 0u32;
557 let mut best_len = 0usize;
558
559 loop {
560 let dist = if hashpos as usize <= wpos {
562 wpos - hashpos as usize
563 } else {
564 wpos + self.window_mask + 1 - hashpos as usize
565 };
566
567 if (dist as i32) < prev_dist {
568 break;
569 }
570 prev_dist = dist as i32;
571
572 if dist > 0 && dist <= max_dist {
573 let mut i = pos;
575 let mut j = pos - dist;
576
577 if self.numzeros > 3 {
579 let r =
580 ((self.numzeros - 1) as usize).min(self.zeros[hashpos as usize] as usize);
581 let skip = if i + r >= end { end - i - 1 } else { r };
582 i += skip;
583 j += skip;
584 }
585
586 while i < end && self.data[i] == self.data[j] {
588 i += 1;
589 j += 1;
590 }
591
592 let len = i - pos;
593
594 if len >= self.min_length && len + 2 >= best_len {
596 let dist_symbol =
597 if let Some(&sym) = self.special_dist_table.get(&(dist as i32)) {
598 sym
599 } else {
600 self.num_special_distances + dist - 1
601 };
602 found_match(len, dist_symbol);
603 if len > best_len {
604 best_len = len;
605 }
606 }
607 }
608
609 chain_length += 1;
610 if chain_length >= self.max_chain_length {
611 break;
612 }
613
614 if self.numzeros >= 3 && best_len > self.numzeros as usize {
616 if hashpos == self.chainz[hashpos as usize] {
618 break;
619 }
620 hashpos = self.chainz[hashpos as usize];
621 if self.zeros[hashpos as usize] != self.numzeros {
622 break;
623 }
624 } else {
625 if hashpos == self.chain[hashpos as usize] {
627 break;
628 }
629 hashpos = self.chain[hashpos as usize];
630 if self.val[hashpos as usize] != hashval as i32 {
631 break;
633 }
634 }
635 }
636 }
637}
638
639pub fn apply_lz77_backref(
648 tokens: &[Token],
649 num_contexts: usize,
650 force_huffman: bool,
651 distance_multiplier: i32,
652) -> Option<(Vec<Token>, Lz77Params)> {
653 if tokens.is_empty() {
654 return None;
655 }
656
657 let mut lz77 = Lz77Params::new(num_contexts, force_huffman);
658
659 let sce = SymbolCostEstimator::new(num_contexts, force_huffman, tokens, &lz77);
661
662 let mut sym_cost = vec![0.0f32; tokens.len() + 1];
664 for (i, token) in tokens.iter().enumerate() {
665 let e = UintCoder::encode(token.value);
666 let cost = sce.symbol_cost(token.context() as usize, e.token as usize) + e.nbits as f32;
667 sym_cost[i + 1] = sym_cost[i] + cost;
668 }
669
670 let mut out = Vec::with_capacity(tokens.len());
671 let mut bit_decrease: f32 = 0.0;
672 let total_symbols = tokens.len();
673
674 let max_distance = tokens.len();
675 let min_length = lz77.min_length as usize;
676 let max_length = tokens.len();
677
678 let mut window_size = 1usize;
680 while window_size < max_distance && window_size < WINDOW_SIZE {
681 window_size <<= 1;
682 }
683
684 let mut chain = HashChain::new(
685 tokens,
686 window_size,
687 min_length,
688 max_length,
689 distance_multiplier,
690 );
691
692 const MAX_LAZY_MATCH_LEN: usize = 256;
693 let mut already_updated = false;
694
695 let mut i = 0usize;
696 while i < tokens.len() {
697 out.push(tokens[i]);
698
699 if !already_updated {
700 chain.update(i);
701 }
702 already_updated = false;
703
704 let (mut dist_symbol, mut len) = chain.find_match(i, max_distance);
705
706 if len >= min_length {
707 if len < MAX_LAZY_MATCH_LEN && i + 1 < tokens.len() {
709 chain.update(i + 1);
710 already_updated = true;
711 let (dist_symbol2, len2) = chain.find_match(i + 1, max_distance);
712 if len2 > len {
713 i += 1;
716 already_updated = false;
717 len = len2;
718 dist_symbol = dist_symbol2;
719 out.push(tokens[i]);
720 }
721 }
722
723 let literal_cost = sym_cost[i + len] - sym_cost[i];
725 let lz77_len = len - min_length;
726
727 let lz77_cost = len_cost(lz77_len as u32)
729 + dist_cost(dist_symbol as u32)
730 + sce.add_symbol_cost(out.last().unwrap().context() as usize);
731
732 if lz77_cost <= literal_cost {
733 let last_token = out.last_mut().unwrap();
735 last_token.value = lz77_len as u32;
736 last_token.set_lz77_length(true);
737
738 out.push(Token::new(lz77.distance_context, dist_symbol as u32));
739
740 bit_decrease += literal_cost - lz77_cost;
741 } else {
742 for j in 1..len {
744 out.push(tokens[i + j]);
745 }
746 }
747
748 if already_updated {
750 chain.update_range(i + 2, len - 2);
751 already_updated = false;
752 } else {
753 chain.update_range(i + 1, len - 1);
754 }
755 i += len - 1;
756 }
757 i += 1;
760 }
761
762 let threshold = total_symbols as f32 * 0.2 + 16.0;
764 #[cfg(feature = "debug-tokens")]
765 eprintln!(
766 "[LZ77-backref] bit_decrease={:.1}, threshold={:.1}, tokens: {} -> {}, matches={}",
767 bit_decrease,
768 threshold,
769 total_symbols,
770 out.len(),
771 out.iter().filter(|t| t.is_lz77_length()).count()
772 );
773 if bit_decrease > threshold {
774 lz77.enabled = true;
775 Some((out, lz77))
776 } else {
777 None
778 }
779}
780
781pub fn apply_lz77_rle(
791 tokens: &[Token],
792 num_contexts: usize,
793 force_huffman: bool,
794 distance_multiplier: i32,
795) -> Option<(Vec<Token>, Lz77Params)> {
796 if tokens.is_empty() {
797 return None;
798 }
799
800 let mut lz77 = Lz77Params::new(num_contexts, force_huffman);
801
802 let rle_distance_symbol: u32 = if distance_multiplier > 0 { 1 } else { 0 };
808
809 let sce = SymbolCostEstimator::new(num_contexts, force_huffman, tokens, &lz77);
812
813 let mut sym_cost = vec![0.0f32; tokens.len() + 1];
815 for (i, token) in tokens.iter().enumerate() {
816 let e = UintCoder::encode(token.value);
817 let cost = sce.symbol_cost(token.context() as usize, e.token as usize) + e.nbits as f32;
818 sym_cost[i + 1] = sym_cost[i] + cost;
819 }
820
821 let mut out = Vec::with_capacity(tokens.len());
822 let mut bit_decrease: f32 = 0.0;
823 let total_symbols = tokens.len();
824
825 let mut i = 0;
826 while i < tokens.len() {
827 let mut num_to_copy = 0;
830 if i > 0 {
831 let prev_value = tokens[i - 1].value;
832 while i + num_to_copy < tokens.len() && tokens[i + num_to_copy].value == prev_value {
833 num_to_copy += 1;
834 }
835 }
836
837 if num_to_copy == 0 {
838 out.push(tokens[i]);
839 i += 1;
840 continue;
841 }
842
843 let literal_cost = sym_cost[i + num_to_copy] - sym_cost[i];
845
846 let lz77_cost = if num_to_copy >= lz77.min_length as usize {
848 let lz77_len = num_to_copy - lz77.min_length as usize;
849 ceil_log2_nonzero((lz77_len + 1) as u32) as f32 + 1.0
851 } else {
852 0.0
853 };
854
855 if num_to_copy < lz77.min_length as usize || literal_cost <= lz77_cost {
856 for j in 0..num_to_copy {
858 out.push(tokens[i + j]);
859 }
860 i += num_to_copy;
861 continue;
862 }
863
864 let lz77_len = (num_to_copy - lz77.min_length as usize) as u32;
866 out.push(Token::lz77_length(tokens[i].context(), lz77_len));
867
868 out.push(Token::new(lz77.distance_context, rle_distance_symbol));
870
871 bit_decrease += literal_cost - lz77_cost;
872 i += num_to_copy;
873 }
874
875 let threshold = total_symbols as f32 * 0.2 + 16.0;
877 #[cfg(feature = "debug-tokens")]
878 eprintln!(
879 "[LZ77-RLE] bit_decrease={:.1}, threshold={:.1}, tokens: {} -> {}, runs_found={}",
880 bit_decrease,
881 threshold,
882 total_symbols,
883 out.len(),
884 out.iter().filter(|t| t.is_lz77_length()).count()
885 );
886 if bit_decrease > threshold {
887 lz77.enabled = true;
888 Some((out, lz77))
889 } else {
890 None
891 }
892}
893
/// Returns `ceil(log2(x))` for a non-zero `x`.
///
/// `x` must be greater than zero; this is checked with a debug assertion.
fn ceil_log2_nonzero(x: u32) -> u32 {
    debug_assert!(x > 0);
    let floor = 31 - x.leading_zeros();
    // Exact powers of two need no rounding up.
    floor + u32::from(!x.is_power_of_two())
}
904
/// Selects which LZ77 strategy `apply_lz77` runs.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum Lz77Method {
    /// Run-length matching only (`apply_lz77_rle`). This is the default.
    #[default]
    Rle,
    /// Greedy hash-chain matching (`apply_lz77_backref`).
    Greedy,
    /// Cost-driven matching seeded by the greedy result (`apply_lz77_optimal`).
    Optimal,
}
921
922pub fn apply_lz77(
934 tokens: &[Token],
935 num_contexts: usize,
936 force_huffman: bool,
937 method: Lz77Method,
938 distance_multiplier: i32,
939) -> Option<(Vec<Token>, Lz77Params)> {
940 match method {
941 Lz77Method::Rle => apply_lz77_rle(tokens, num_contexts, force_huffman, distance_multiplier),
942 Lz77Method::Greedy => {
943 apply_lz77_backref(tokens, num_contexts, force_huffman, distance_multiplier)
944 }
945 Lz77Method::Optimal => {
946 apply_lz77_optimal(tokens, num_contexts, force_huffman, distance_multiplier)
947 }
948 }
949}
950
951pub fn apply_lz77_optimal(
960 tokens: &[Token],
961 num_contexts: usize,
962 force_huffman: bool,
963 distance_multiplier: i32,
964) -> Option<(Vec<Token>, Lz77Params)> {
965 if tokens.is_empty() {
966 return None;
967 }
968
969 let greedy_result =
972 apply_lz77_backref(tokens, num_contexts, force_huffman, distance_multiplier);
973 let greedy_tokens = match &greedy_result {
974 Some((t, _)) => t,
975 None => return None,
976 };
977
978 let mut lz77 = Lz77Params::new(num_contexts, force_huffman);
979 lz77.enabled = true;
980
981 let sce = SymbolCostEstimator::new(num_contexts + 1, force_huffman, greedy_tokens, &lz77);
983
984 let mut sym_cost = vec![0.0f32; tokens.len() + 1];
986 for (i, token) in tokens.iter().enumerate() {
987 let e = UintCoder::encode(token.value);
988 let cost = sce.symbol_cost(token.context() as usize, e.token as usize) + e.nbits as f32;
989 sym_cost[i + 1] = sym_cost[i] + cost;
990 }
991
992 let max_distance = tokens.len();
994 let min_length = lz77.min_length as usize;
995 let max_length = tokens.len();
996
997 let mut window_size = 1usize;
998 while window_size < max_distance && window_size < WINDOW_SIZE {
999 window_size <<= 1;
1000 }
1001
1002 let mut chain = HashChain::new(
1003 tokens,
1004 window_size,
1005 min_length,
1006 max_length,
1007 distance_multiplier,
1008 );
1009
1010 struct PrefixInfo {
1012 len: u32,
1013 dist_symbol: u32, ctx: u32,
1015 total_cost: f32,
1016 }
1017
1018 let n = tokens.len();
1019 let mut prefix_costs: Vec<PrefixInfo> = (0..=n)
1020 .map(|_| PrefixInfo {
1021 len: 0,
1022 dist_symbol: 0,
1023 ctx: 0,
1024 total_cost: f32::MAX,
1025 })
1026 .collect();
1027 prefix_costs[0].total_cost = 0.0;
1028
1029 let mut rle_length = 0usize;
1030 let mut skip_lz77 = 0usize;
1031 let mut dist_symbols: Vec<u32> = Vec::new();
1032
1033 for i in 0..n {
1034 chain.update(i);
1035
1036 let lit_cost = prefix_costs[i].total_cost + sym_cost[i + 1] - sym_cost[i];
1038 if prefix_costs[i + 1].total_cost > lit_cost {
1039 prefix_costs[i + 1].dist_symbol = 0;
1040 prefix_costs[i + 1].len = 1;
1041 prefix_costs[i + 1].ctx = tokens[i].context();
1042 prefix_costs[i + 1].total_cost = lit_cost;
1043 }
1044
1045 if skip_lz77 > 0 {
1046 skip_lz77 -= 1;
1047 continue;
1048 }
1049
1050 dist_symbols.clear();
1052 chain.find_matches(i, max_distance, |len, dist_symbol| {
1053 if dist_symbols.len() <= len {
1054 dist_symbols.resize(len + 1, dist_symbol as u32);
1055 }
1056 if (dist_symbol as u32) < dist_symbols[len] {
1057 dist_symbols[len] = dist_symbol as u32;
1058 }
1059 });
1060
1061 if dist_symbols.len() <= min_length {
1062 continue;
1063 }
1064
1065 {
1067 let mut best_cost = dist_symbols[dist_symbols.len() - 1];
1068 for j in (min_length..dist_symbols.len()).rev() {
1069 if dist_symbols[j] < best_cost {
1070 best_cost = dist_symbols[j];
1071 }
1072 dist_symbols[j] = best_cost;
1073 }
1074 }
1075
1076 for (j, &dsym) in dist_symbols.iter().enumerate().skip(min_length) {
1078 let target = i + j;
1079 if target > n {
1080 break;
1081 }
1082 let lz77_cost =
1083 sce.len_cost(tokens[i].context() as usize, (j - min_length) as u32, &lz77)
1084 + sce.dist_cost_sce(dsym, &lz77);
1085 let cost = prefix_costs[i].total_cost + lz77_cost;
1086 if prefix_costs[target].total_cost > cost {
1087 prefix_costs[target].len = j as u32;
1088 prefix_costs[target].dist_symbol = dsym + 1; prefix_costs[target].ctx = tokens[i].context();
1090 prefix_costs[target].total_cost = cost;
1091 }
1092 }
1093
1094 let last_dist = dist_symbols[dist_symbols.len() - 1];
1096 if (last_dist == 0 && distance_multiplier == 0)
1097 || (last_dist == 1 && distance_multiplier != 0)
1098 {
1099 rle_length += 1;
1100 } else {
1101 rle_length = 0;
1102 }
1103 if rle_length >= 8 && dist_symbols.len() > 9 {
1104 skip_lz77 = dist_symbols.len() - 10;
1105 rle_length = 0;
1106 }
1107 }
1108
1109 let mut out = Vec::with_capacity(n);
1111 let mut pos = n;
1112 while pos > 0 {
1113 let info = &prefix_costs[pos];
1114 let is_lz77 = info.dist_symbol != 0;
1115
1116 if is_lz77 {
1117 let dist_symbol = info.dist_symbol - 1;
1118 out.push(Token::new(lz77.distance_context, dist_symbol));
1119 }
1120
1121 let val = if is_lz77 {
1122 info.len - min_length as u32
1123 } else {
1124 tokens[pos - 1].value
1125 };
1126 let mut tok = Token::new(info.ctx, val);
1127 tok.set_lz77_length(is_lz77);
1128 out.push(tok);
1129
1130 pos -= info.len as usize;
1131 }
1132
1133 out.reverse();
1134 Some((out, lz77))
1135}
1136
1137#[allow(dead_code)] pub fn apply_lz77_best(
1144 tokens: &[Token],
1145 num_contexts: usize,
1146 force_huffman: bool,
1147 distance_multiplier: i32,
1148) -> Option<(Vec<Token>, Lz77Params)> {
1149 let rle_result = apply_lz77_rle(tokens, num_contexts, force_huffman, distance_multiplier);
1150 let backref_result =
1151 apply_lz77_backref(tokens, num_contexts, force_huffman, distance_multiplier);
1152
1153 match (&rle_result, &backref_result) {
1154 (Some((rle_tokens, _)), Some((backref_tokens, _))) => {
1155 if backref_tokens.len() <= rle_tokens.len() {
1157 backref_result
1158 } else {
1159 rle_result
1160 }
1161 }
1162 (Some(_), None) => rle_result,
1163 (None, Some(_)) => backref_result,
1164 (None, None) => None,
1165 }
1166}
1167
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `count` context-0 tokens that all hold `value`.
    fn constant_run(value: u32, count: usize) -> Vec<Token> {
        vec![Token::new(0, value); count]
    }

    /// Builds `repeats` back-to-back copies of `pattern` as context-0 tokens.
    fn cyclic_tokens(pattern: &[u32], repeats: usize) -> Vec<Token> {
        (0..repeats)
            .flat_map(|_| pattern.iter().map(|&value| Token::new(0, value)))
            .collect()
    }

    #[test]
    fn test_ceil_log2_nonzero() {
        let expected = [(1, 0), (2, 1), (3, 2), (4, 2), (5, 3), (8, 3), (9, 4)];
        for &(input, output) in expected.iter() {
            assert_eq!(ceil_log2_nonzero(input), output);
        }
    }

    #[test]
    fn test_no_rle_on_short_stream() {
        // Three tokens can never pay for the acceptance threshold.
        let tokens = constant_run(5, 3);
        assert!(apply_lz77_rle(&tokens, 1, false, 0).is_none());
    }

    #[test]
    fn test_rle_on_long_run() {
        // One leading token followed by 200 repeats of it.
        let tokens = constant_run(5, 201);

        if let Some((lz77_tokens, params)) = apply_lz77_rle(&tokens, 1, false, 0) {
            assert!(params.enabled);
            assert!(lz77_tokens.len() < tokens.len());
            assert!(lz77_tokens.iter().any(|t| t.is_lz77_length()));
        }
    }

    #[test]
    fn test_rle_preserves_non_runs() {
        // Distinct values, then a long run, then distinct values again.
        let mut tokens: Vec<Token> = (0..10).map(|i| Token::new(0, i)).collect();
        tokens.extend(constant_run(42, 100));
        tokens.extend((0..10).map(|i| Token::new(0, i + 100)));

        if let Some((lz77_tokens, params)) = apply_lz77_rle(&tokens, 1, false, 0) {
            assert!(params.enabled);
            assert!(lz77_tokens.len() < tokens.len());
            assert_eq!(lz77_tokens[0].value, 0);
            assert!(!lz77_tokens[0].is_lz77_length());
        }
    }

    #[test]
    fn test_empty_stream() {
        assert!(apply_lz77_rle(&[], 1, false, 0).is_none());
    }

    #[test]
    fn test_backref_empty_stream() {
        assert!(apply_lz77_backref(&[], 1, false, 0).is_none());
    }

    #[test]
    fn test_backref_short_stream() {
        let tokens = constant_run(5, 3);
        assert!(apply_lz77_backref(&tokens, 1, false, 0).is_none());
    }

    #[test]
    fn test_backref_on_repeating_pattern() {
        let tokens = cyclic_tokens(&[10, 20, 30], 100);

        if let Some((lz77_tokens, params)) = apply_lz77_backref(&tokens, 1, false, 0) {
            assert!(params.enabled);
            assert!(
                lz77_tokens.len() < tokens.len(),
                "backref should compress pattern: {} vs {}",
                lz77_tokens.len(),
                tokens.len()
            );
            assert!(lz77_tokens.iter().any(|t| t.is_lz77_length()));
        }
    }

    #[test]
    fn test_backref_finds_longer_matches_than_rle() {
        // A period-5 pattern contains no runs, so only back-references can help.
        let tokens = cyclic_tokens(&[1, 2, 3, 4, 5], 50);

        let rle_result = apply_lz77_rle(&tokens, 1, false, 0);
        let backref_result = apply_lz77_backref(&tokens, 1, false, 0);

        if let Some((backref_tokens, _)) = &backref_result {
            match &rle_result {
                None => assert!(backref_tokens.len() < tokens.len()),
                Some((rle_tokens, _)) => assert!(backref_tokens.len() <= rle_tokens.len()),
            }
        }
    }

    #[test]
    fn test_backref_with_distance_multiplier() {
        let image_width = 64;

        // Twenty identical rows whose values repeat every 16 columns.
        let tokens: Vec<Token> = (0..20)
            .flat_map(|_row| (0..image_width).map(|col| Token::new(0, (col % 16) as u32)))
            .collect();

        let _result_no_mult = apply_lz77_backref(&tokens, 1, false, 0);
        let result_with_mult = apply_lz77_backref(&tokens, 1, false, image_width);

        if let Some((tokens_mult, params)) = result_with_mult {
            assert!(params.enabled);
            assert!(tokens_mult.len() < tokens.len());
        }
    }

    #[test]
    fn test_special_distance() {
        let expected = [(0, 64), (1, 1), (2, 65), (3, 63)];
        for &(index, distance) in expected.iter() {
            assert_eq!(special_distance(index, 64), distance);
        }
    }

    #[test]
    fn test_len_cost() {
        for len in 0..1000 {
            let cost = len_cost(len);
            assert!(cost >= 0.0, "len_cost({}) should be non-negative", len);
            assert!(cost < 100.0, "len_cost({}) should be reasonable", len);
        }
    }

    #[test]
    fn test_dist_cost() {
        for dist in 0..10000 {
            let cost = dist_cost(dist);
            assert!(cost >= 0.0, "dist_cost({}) should be non-negative", dist);
            assert!(cost < 100.0, "dist_cost({}) should be reasonable", dist);
        }
    }

    #[test]
    fn test_apply_lz77_method_enum() {
        let tokens = constant_run(5, 201);

        let rle_result = apply_lz77(&tokens, 1, false, Lz77Method::Rle, 0);
        if let Some((_, params)) = &rle_result {
            assert!(params.enabled);
        }

        let greedy_result = apply_lz77(&tokens, 1, false, Lz77Method::Greedy, 0);
        if let Some((_, params)) = &greedy_result {
            assert!(params.enabled);
        }
    }

    #[test]
    fn test_apply_lz77_best() {
        let tokens = cyclic_tokens(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], 50);

        if let Some((best_tokens, params)) = apply_lz77_best(&tokens, 1, false, 0) {
            assert!(params.enabled);
            assert!(best_tokens.len() < tokens.len());
        }
    }

    #[test]
    fn test_hash_chain_basic() {
        // Positions 4..7 repeat positions 0..3, i.e. a match at distance 4.
        let tokens: Vec<Token> = [10, 20, 30, 40, 10, 20, 30]
            .iter()
            .map(|&value| Token::new(0, value))
            .collect();

        let mut chain = HashChain::new(&tokens, 16, 3, 100, 0);
        (0..tokens.len()).for_each(|i| chain.update(i));

        let (dist_symbol, len) = chain.find_match(4, 10);
        assert!(len >= 3, "should find match of length >= 3, got {}", len);
        assert_eq!(dist_symbol, 3, "distance symbol for dist=4 should be 3");
    }
}