1use hashbrown::HashMap;
16
17use super::token::{Lz77UintCoder, Token, UintCoder};
18
/// Upper bound on the hash-chain window size (2^20 positions).
///
/// `apply_lz77_backref` uses the smallest power of two that is at least the
/// stream length, capped at this value.
const WINDOW_SIZE: usize = 1 << 20;

/// Number of entries in `SPECIAL_DISTANCES`.
const NUM_SPECIAL_DISTANCES: usize = 120;

/// Special-distance table.
///
/// Entry `[a, b]` resolves to the distance `a + multiplier * b` (see
/// `special_distance`). When the multiplier is an image width (the tests pass
/// `image_width`), `a` presumably acts as a horizontal offset and `b` as a row
/// offset. When several entries resolve to the same distance, `HashChain::new`
/// keeps the lowest index.
#[rustfmt::skip]
const SPECIAL_DISTANCES: [[i8; 2]; NUM_SPECIAL_DISTANCES] = [
    [0, 1], [1, 0], [1, 1], [-1, 1], [0, 2], [2, 0], [1, 2], [-1, 2],
    [2, 1], [-2, 1], [2, 2], [-2, 2], [0, 3], [3, 0], [1, 3], [-1, 3],
    [3, 1], [-3, 1], [2, 3], [-2, 3], [3, 2], [-3, 2], [0, 4], [4, 0],
    [1, 4], [-1, 4], [4, 1], [-4, 1], [3, 3], [-3, 3], [2, 4], [-2, 4],
    [4, 2], [-4, 2], [0, 5], [3, 4], [-3, 4], [4, 3], [-4, 3], [5, 0],
    [1, 5], [-1, 5], [5, 1], [-5, 1], [2, 5], [-2, 5], [5, 2], [-5, 2],
    [4, 4], [-4, 4], [3, 5], [-3, 5], [5, 3], [-5, 3], [0, 6], [6, 0],
    [1, 6], [-1, 6], [6, 1], [-6, 1], [2, 6], [-2, 6], [6, 2], [-6, 2],
    [4, 5], [-4, 5], [5, 4], [-5, 4], [3, 6], [-3, 6], [6, 3], [-6, 3],
    [0, 7], [7, 0], [1, 7], [-1, 7], [5, 5], [-5, 5], [7, 1], [-7, 1],
    [4, 6], [-4, 6], [6, 4], [-6, 4], [2, 7], [-2, 7], [7, 2], [-7, 2],
    [3, 7], [-3, 7], [7, 3], [-7, 3], [5, 6], [-5, 6], [6, 5], [-6, 5],
    [8, 0], [4, 7], [-4, 7], [7, 4], [-7, 4], [8, 1], [8, 2], [6, 6],
    [-6, 6], [8, 3], [5, 7], [-5, 7], [7, 5], [-7, 5], [8, 4], [6, 7],
    [-6, 7], [7, 6], [-7, 6], [8, 5], [7, 7], [-7, 7], [8, 6], [8, 7],
];
46
47#[inline]
49fn special_distance(index: usize, multiplier: i32) -> i32 {
50 SPECIAL_DISTANCES[index][0] as i32 + multiplier * SPECIAL_DISTANCES[index][1] as i32
51}
52
/// Estimated bit cost of each LZ77 length token, indexed by token.
///
/// `len_cost` clamps the index to the last entry and adds the raw extra bits
/// separately. The values are fixed empirical estimates; their derivation is
/// not documented here.
#[rustfmt::skip]
#[allow(clippy::excessive_precision)]
const LEN_COST_TABLE: [f32; 17] = [
    2.797667318563126, 3.213177690381199, 2.5706009246743737,
    2.408392498667534, 2.829649191872326, 3.3923087753324577,
    4.029267451554331, 4.415576699706408, 4.509357574741465,
    9.21481543803004, 10.020590190114898, 11.858671627804766,
    12.45853300490526, 11.713105831990857, 12.561996324849314,
    13.775477692278367, 13.174027068768641,
];
65
/// Estimated bit cost of each distance token, indexed by the token that
/// `hybrid_uint_encode_7_0_0` produces.
///
/// `dist_cost` clamps the index to the last entry and adds the raw extra bits
/// separately. The values are fixed empirical estimates; their derivation is
/// not documented here.
#[rustfmt::skip]
#[allow(clippy::excessive_precision)]
const DIST_COST_TABLE: [f32; 128] = [
    6.368282626312716, 5.680793277090298, 8.347404197105247,
    7.641619201599141, 6.914328374119438, 7.959808291537444,
    8.70023120759855, 8.71378518934703, 9.379132523982769,
    9.110472749092708, 9.159029569270908, 9.430936766731973,
    7.278284055315169, 7.8278514904267755, 10.026641158289236,
    9.976049229827066, 9.64351607048908, 9.563403863480442,
    10.171474111762747, 10.45950155077234, 9.994813912104219,
    10.322524683741156, 8.465808729388186, 8.756254166066853,
    10.160930174662234, 10.247329273413435, 10.04090403724809,
    10.129398517544082, 9.342311691539546, 9.07608009102374,
    10.104799540677513, 10.378079384990906, 10.165828974075072,
    10.337595322341553, 7.940557464567944, 10.575665823319431,
    11.023344321751955, 10.736144698831827, 11.118277044595054,
    7.468468230648442, 10.738305230932939, 10.906980780216568,
    10.163468216353817, 10.17805759656433, 11.167283670483565,
    11.147050200274544, 10.517921919244333, 10.651764778156886,
    10.17074446448919, 11.217636876224745, 11.261630721139484,
    11.403140815247259, 10.892472096873417, 11.1859607804481,
    8.017346947551262, 7.895143720278828, 11.036577113822025,
    11.170562110315794, 10.326988722591086, 10.40872184751056,
    11.213498225466386, 11.30580635516863, 10.672272515665442,
    10.768069466228063, 11.145257364153565, 11.64668307145549,
    10.593156194627339, 11.207499484844943, 10.767517766396908,
    10.826629811407042, 10.737764794499988, 10.6200448518045,
    10.191315385198092, 8.468384171390085, 11.731295299170432,
    11.824619886654398, 10.41518844301179, 10.16310536548649,
    10.539423685097576, 10.495136599328031, 10.469112847728267,
    11.72057686174922, 10.910326337834674, 11.378921834673758,
    11.847759036098536, 11.92071647623854, 10.810628276345282,
    11.008601085273893, 11.910326337834674, 11.949212023423133,
    11.298614839104337, 11.611603659010392, 10.472930394619985,
    11.835564720850282, 11.523267392285337, 12.01055816679611,
    8.413029688994023, 11.895784139536406, 11.984679534970505,
    11.220654278717394, 11.716311684833672, 10.61036646226114,
    10.89849965960364, 10.203762898863669, 10.997560826267238,
    11.484217379438984, 11.792836176993665, 12.24310468755171,
    11.464858097919262, 12.212747017409377, 11.425595666074955,
    11.572048533398757, 12.742093965163013, 11.381874288645637,
    12.191870445817015, 11.683156920035426, 11.152442115262197,
    11.90303691580457, 11.653292787169159, 11.938615382266098,
    16.970641701570223, 16.853602280380002, 17.26240782594733,
    16.644655390108507, 17.14310889757499, 16.910935455445955,
    17.505678976959697, 17.213498225466388,
];
115
116fn len_cost(len: u32) -> f32 {
118 let (tok, nbits) = if len == 0 {
120 (0u32, 0u32)
121 } else {
122 let n = 31 - len.leading_zeros();
123 (1 + n, n)
124 };
125 let table_size = LEN_COST_TABLE.len();
126 let tok_idx = (tok as usize).min(table_size - 1);
127 LEN_COST_TABLE[tok_idx] + nbits as f32
128}
129
130fn dist_cost(dist: u32) -> f32 {
132 let (tok, nbits) = hybrid_uint_encode_7_0_0(dist);
134 let table_size = DIST_COST_TABLE.len();
135 let tok_idx = (tok as usize).min(table_size - 1);
136 DIST_COST_TABLE[tok_idx] + nbits as f32
137}
138
/// Splits `value` into `(token, extra_bit_count)` using the hybrid-uint
/// configuration `split_exponent = 7, msb_in_token = 0, lsb_in_token = 0`.
///
/// Values below the split token `1 << 7 = 128` are their own token and need
/// no extra bits. For larger values, with `n = floor(log2(value))` (so
/// `n >= 7`), the token is `128 + (n - 7)` and is followed by `n` raw bits.
///
/// Every value in `0..128` therefore has a distinct token, which lets
/// `DIST_COST_TABLE` (128 entries) be indexed directly by small distance
/// symbols.
fn hybrid_uint_encode_7_0_0(value: u32) -> (u32, u32) {
    const SPLIT_EXPONENT: u32 = 7;
    const SPLIT_TOKEN: u32 = 1 << SPLIT_EXPONENT;

    if value < SPLIT_TOKEN {
        (value, 0)
    } else {
        // value >= 128, so n >= 7 and the subtraction cannot underflow.
        let n = u32::BITS - 1 - value.leading_zeros();
        (SPLIT_TOKEN + (n - SPLIT_EXPONENT), n)
    }
}
152
/// LZ77 configuration selected for a token stream.
#[derive(Debug, Clone)]
pub struct Lz77Params {
    /// Set to `true` by `apply_lz77_rle` / `apply_lz77_backref` when they
    /// decide LZ77 is worthwhile and return a result.
    pub enabled: bool,
    /// Offset added to LZ77 length tokens when they are placed in a context's
    /// symbol alphabet for cost estimation.
    pub min_symbol: u32,
    /// Shortest copy that may be emitted; emitted length values are
    /// `len - min_length`.
    pub min_length: u32,
    /// Context used for distance tokens.
    pub distance_context: u32,
}

impl Lz77Params {
    /// Returns disabled parameters for a stream with `num_contexts` contexts.
    ///
    /// Sets `min_length = 3` and `distance_context = num_contexts`.
    /// `min_symbol` is 512 when `force_huffman` is set and 224 otherwise.
    pub fn new(num_contexts: usize, force_huffman: bool) -> Self {
        let min_symbol = if force_huffman { 512 } else { 224 };
        Self {
            distance_context: num_contexts as u32,
            min_length: 3,
            min_symbol,
            enabled: false,
        }
    }
}
176
177struct SymbolCostEstimator {
179 bits: Vec<f32>,
181 max_alphabet_size: usize,
182}
183
184impl SymbolCostEstimator {
185 fn new(num_contexts: usize, force_huffman: bool, tokens: &[Token], lz77: &Lz77Params) -> Self {
186 const ANS_LOG_TAB_SIZE: f32 = 12.0;
187
188 let mut counts: Vec<Vec<u32>> = vec![vec![]; num_contexts];
190 let mut total_counts = vec![0u32; num_contexts];
191
192 for token in tokens {
193 let (tok, _nbits) = if token.is_lz77_length {
194 let e = Lz77UintCoder::encode(token.value);
195 (e.token + lz77.min_symbol, e.nbits)
196 } else {
197 let e = UintCoder::encode(token.value);
198 (e.token, e.nbits)
199 };
200 let ctx = token.context as usize;
201 if ctx < num_contexts {
202 let sym = tok as usize;
203 if sym >= counts[ctx].len() {
204 counts[ctx].resize(sym + 1, 0);
205 }
206 counts[ctx][sym] += 1;
207 total_counts[ctx] += 1;
208 }
209 }
210
211 let max_alphabet_size = counts.iter().map(|c| c.len()).max().unwrap_or(0);
212 let mut bits = vec![0.0f32; num_contexts * max_alphabet_size];
213
214 for ctx in 0..num_contexts {
215 let total = total_counts[ctx];
216 if total == 0 {
217 continue;
218 }
219 let inv_total = 1.0 / (total as f32 + 1e-8);
220 for sym in 0..counts[ctx].len() {
221 let cnt = counts[ctx][sym];
222 let cost = if cnt != 0 && cnt != total {
223 let p = cnt as f32 * inv_total;
224 let c = -p.log2();
225 if force_huffman { c.ceil() } else { c }
226 } else if cnt == 0 {
227 ANS_LOG_TAB_SIZE } else {
229 0.0 };
231 bits[ctx * max_alphabet_size + sym] = cost;
232 }
233 }
234
235 Self {
236 bits,
237 max_alphabet_size,
238 }
239 }
240
241 #[inline]
242 fn symbol_cost(&self, ctx: usize, sym: usize) -> f32 {
243 if sym < self.max_alphabet_size {
244 self.bits[ctx * self.max_alphabet_size + sym]
245 } else {
246 12.0 }
248 }
249
250 fn add_symbol_cost(&self, ctx: usize) -> f32 {
252 let mut total_cost = 0.0f32;
254 let mut total_count = 0u32;
255 for sym in 0..self.max_alphabet_size {
256 let cost = self.bits[ctx * self.max_alphabet_size + sym];
257 if cost < 12.0 {
258 total_cost += cost;
260 total_count += 1;
261 }
262 }
263 if total_count == 0 {
264 return 0.0;
265 }
266 (6.0 - total_cost / total_count as f32).max(0.0)
268 }
269
270 #[allow(dead_code)] fn len_cost(&self, ctx: usize, len: u32, lz77: &Lz77Params) -> f32 {
273 let (tok, nbits) = if len == 0 {
275 (0u32, 0u32)
276 } else {
277 let n = 31 - len.leading_zeros();
278 (1 + n, n)
279 };
280 let sym = tok + lz77.min_symbol;
281 nbits as f32 + self.symbol_cost(ctx, sym as usize)
282 }
283
284 #[allow(dead_code)] fn dist_cost_sce(&self, dist_symbol: u32, lz77: &Lz77Params) -> f32 {
287 let (tok, nbits) = UintCoder::encode(dist_symbol).into();
288 nbits as f32 + self.symbol_cost(lz77.distance_context as usize, tok as usize)
289 }
290}
291
/// Hash-chain match finder over token values, used by `apply_lz77_backref`.
///
/// Positions are stored in a circular window of `window_size` slots
/// (`pos & window_mask`). Two chains link earlier positions:
/// - a hash chain keyed on the next three token values;
/// - a "zero-run" chain keyed on the length of the run of zero values that
///   starts at a position, which speeds up long runs of zeros.
struct HashChain {
    /// Token values copied from the input stream.
    data: Vec<u32>,
    /// Number of tokens in `data`.
    size: usize,
    /// Number of window slots.
    window_size: usize,
    /// `window_size - 1`; used as `pos & window_mask`, which assumes
    /// `window_size` is a power of two (the backref caller guarantees this).
    window_mask: usize,
    /// Shortest match that `find_matches` reports.
    min_length: usize,
    /// Longest match that `find_matches` will extend to.
    max_length: usize,

    /// Number of hash buckets (32768).
    #[allow(dead_code)]
    hash_num_values: usize,
    /// `hash_num_values - 1`, applied to the combined hash.
    hash_mask: usize,
    /// Shift applied to the 2nd and 3rd token values when hashing.
    hash_shift: u32,

    /// Most recent window slot per hash bucket; -1 means empty.
    head: Vec<i32>,
    /// Per window slot, the slot of the previous position with the same hash.
    /// A slot that points at itself terminates the chain.
    chain: Vec<u32>,
    /// Hash of the position stored in each slot; -1 when never written.
    /// Used to detect chain links that point at overwritten slots.
    val: Vec<i32>,

    /// Most recent window slot per zero-run length; -1 means empty.
    headz: Vec<i32>,
    /// Per window slot, the slot of the previous position with the same
    /// zero-run length (self-link terminates).
    chainz: Vec<u32>,
    /// Zero-run length recorded for each slot.
    zeros: Vec<u32>,
    /// Zero-run length at the most recently updated position.
    numzeros: u32,

    /// Maps a distance to its special-distance symbol. When several entries
    /// share a distance, the lowest index is kept.
    special_dist_table: HashMap<i32, usize>,
    /// `NUM_SPECIAL_DISTANCES` when a distance multiplier is active, else 0.
    /// Non-special distances use symbol `num_special_distances + dist - 1`.
    num_special_distances: usize,

    /// Maximum number of candidates examined per `find_matches` call (256).
    max_chain_length: u32,
}

impl HashChain {
    /// Creates an empty chain over `tokens`. No positions are inserted until
    /// `update` / `update_range` is called.
    fn new(
        tokens: &[Token],
        window_size: usize,
        min_length: usize,
        max_length: usize,
        distance_multiplier: i32,
    ) -> Self {
        let size = tokens.len();

        let data: Vec<u32> = tokens.iter().map(|t| t.value).collect();

        let hash_num_values = 32768usize;
        let hash_mask = hash_num_values - 1;
        let hash_shift = 5u32;

        let head = vec![-1i32; hash_num_values];
        // Identity links mean "no predecessor yet".
        let chain: Vec<u32> = (0..window_size as u32).collect();
        let val = vec![-1i32; window_size];

        // Zero-run lengths range over 0..=window_size, hence the extra entry.
        let headz = vec![-1i32; window_size + 1];
        let chainz: Vec<u32> = (0..window_size as u32).collect();
        let zeros = vec![0u32; window_size];

        let mut special_dist_table = HashMap::new();
        let num_special_distances = if distance_multiplier != 0 {
            // Iterate from the highest index down so that, when several
            // entries resolve to the same distance, the lowest index wins.
            for i in (0..NUM_SPECIAL_DISTANCES).rev() {
                let dist = special_distance(i, distance_multiplier);
                if dist > 0 {
                    special_dist_table.insert(dist, i);
                }
            }
            NUM_SPECIAL_DISTANCES
        } else {
            0
        };

        Self {
            data,
            size,
            window_size,
            window_mask: window_size - 1,
            min_length,
            max_length,
            hash_num_values,
            hash_mask,
            hash_shift,
            head,
            chain,
            val,
            headz,
            chainz,
            zeros,
            numzeros: 0,
            special_dist_table,
            num_special_distances,
            max_chain_length: 256,
        }
    }

    /// Hash of the three token values starting at `pos` (low 16 bits of each).
    /// Positions without three remaining tokens all hash to 0.
    fn get_hash(&self, pos: usize) -> u32 {
        if pos + 2 >= self.size {
            return 0;
        }
        let mut result = 0u32;
        result ^= self.data[pos] & 0xFFFF;
        result ^= (self.data[pos + 1] & 0xFFFF) << self.hash_shift;
        result ^= (self.data[pos + 2] & 0xFFFF) << (self.hash_shift * 2);
        result & self.hash_mask as u32
    }

    /// Length of the run of zero values starting at `pos`, limited to the
    /// window (`pos + window_size`) and the stream end.
    ///
    /// When `prev_zeros` (the run length at `pos - 1`) is non-zero, the result
    /// is derived incrementally as `prev_zeros - 1`. The exception is a run
    /// that already spans the whole window and continues at the window's new
    /// last position, which keeps `prev_zeros`.
    fn count_zeros(&self, pos: usize, prev_zeros: u32) -> u32 {
        let end = (pos + self.window_size).min(self.size);
        if prev_zeros > 0 {
            if prev_zeros >= self.window_mask as u32
                && self.data[end - 1] == 0
                && end == pos + self.window_size
            {
                return prev_zeros;
            } else {
                return prev_zeros - 1;
            }
        }
        let mut num = 0u32;
        while pos + (num as usize) < end && self.data[pos + (num as usize)] == 0 {
            num += 1;
        }
        num
    }

    /// Inserts `pos` into the hash chain and the zero-run chain.
    ///
    /// The incremental zero counting assumes positions are updated in
    /// increasing order without gaps.
    fn update(&mut self, pos: usize) {
        let hashval = self.get_hash(pos);
        let wpos = pos & self.window_mask;

        self.val[wpos] = hashval as i32;
        if self.head[hashval as usize] != -1 {
            self.chain[wpos] = self.head[hashval as usize] as u32;
        }
        self.head[hashval as usize] = wpos as i32;

        // A value change ends any zero run carried over from pos - 1.
        if pos > 0 && self.data[pos] != self.data[pos - 1] {
            self.numzeros = 0;
        }
        self.numzeros = self.count_zeros(pos, self.numzeros);

        self.zeros[wpos] = self.numzeros;
        if self.headz[self.numzeros as usize] != -1 {
            self.chainz[wpos] = self.headz[self.numzeros as usize] as u32;
        }
        self.headz[self.numzeros as usize] = wpos as i32;
    }

    /// Calls `update` for positions `pos..pos + len`.
    fn update_range(&mut self, pos: usize, len: usize) {
        for i in 0..len {
            self.update(pos + i);
        }
    }

    /// Returns `(dist_symbol, len)` of the best match at `pos`: the longest
    /// one, with ties broken by the smaller distance symbol. Returns `(0, 1)`
    /// when no match of at least `min_length` is reported.
    fn find_match(&self, pos: usize, max_dist: usize) -> (usize, usize) {
        let mut best_dist_symbol = 0usize;
        let mut best_len = 1usize;

        self.find_matches(pos, max_dist, |len, dist_symbol| {
            if len > best_len || (len == best_len && dist_symbol < best_dist_symbol) {
                best_len = len;
                best_dist_symbol = dist_symbol;
            }
        });

        (best_dist_symbol, best_len)
    }

    /// Walks earlier candidate positions for `pos` and calls
    /// `found_match(len, dist_symbol)` for each match with `len >= min_length`
    /// and `len + 2 >= best length seen so far`.
    ///
    /// Slightly shorter matches are reported too, presumably because they may
    /// carry a cheaper distance symbol. NOTE(review): the zero-run logic reads
    /// `self.numzeros`, which describes the most recently updated position, so
    /// callers are expected to call `update(pos)` last before this.
    fn find_matches<F>(&self, pos: usize, max_dist: usize, mut found_match: F)
    where
        F: FnMut(usize, usize),
    {
        let wpos = pos & self.window_mask;
        let hashval = self.get_hash(pos);
        let mut hashpos = self.chain[wpos];

        let mut prev_dist = 0i32;
        let end = (pos + self.max_length).min(self.size);
        let mut chain_length = 0u32;
        let mut best_len = 0usize;

        loop {
            // Distance from the candidate slot to pos, modulo the window size.
            let dist = if hashpos as usize <= wpos {
                wpos - hashpos as usize
            } else {
                wpos + self.window_mask + 1 - hashpos as usize
            };

            // Candidates should get progressively farther away; a shrinking
            // distance presumably means the chain wrapped around the window.
            if (dist as i32) < prev_dist {
                break;
            }
            prev_dist = dist as i32;

            if dist > 0 && dist <= max_dist {
                let mut i = pos;
                let mut j = pos - dist;

                // Inside a long zero run, skip the prefix that is known to be
                // zero at both positions.
                if self.numzeros > 3 {
                    let r =
                        ((self.numzeros - 1) as usize).min(self.zeros[hashpos as usize] as usize);
                    let skip = if i + r >= end { end - i - 1 } else { r };
                    i += skip;
                    j += skip;
                }

                while i < end && self.data[i] == self.data[j] {
                    i += 1;
                    j += 1;
                }

                let len = i - pos;

                if len >= self.min_length && len + 2 >= best_len {
                    let dist_symbol =
                        if let Some(&sym) = self.special_dist_table.get(&(dist as i32)) {
                            sym
                        } else {
                            self.num_special_distances + dist - 1
                        };
                    found_match(len, dist_symbol);
                    if len > best_len {
                        best_len = len;
                    }
                }
            }

            chain_length += 1;
            if chain_length >= self.max_chain_length {
                break;
            }

            // Once a match longer than the current zero run is known, follow
            // the zero-run chain; otherwise follow the hash chain. A self-link
            // ends either chain, and a mismatching key marks a stale link.
            if self.numzeros >= 3 && best_len > self.numzeros as usize {
                if hashpos == self.chainz[hashpos as usize] {
                    break;
                }
                hashpos = self.chainz[hashpos as usize];
                if self.zeros[hashpos as usize] != self.numzeros {
                    break;
                }
            } else {
                if hashpos == self.chain[hashpos as usize] {
                    break;
                }
                hashpos = self.chain[hashpos as usize];
                if self.val[hashpos as usize] != hashval as i32 {
                    break;
                }
            }
        }
    }
}
579
/// Applies LZ77 with arbitrary back-references found by a hash chain, using
/// one-step lazy matching.
///
/// For each position, the best match is found. If it is shorter than 256,
/// the match starting one token later is also tried and used when it is
/// longer. A match replaces the literals it covers only when the static cost
/// (`len_cost` + `dist_cost` + the context's add-symbol penalty) does not
/// exceed the estimated literal cost.
///
/// A match is emitted as the first covered token turned into an LZ77 length
/// token (value `len - min_length`), followed by a distance token in
/// `lz77.distance_context`.
///
/// Returns `None` for empty input, or when the total estimated saving does not
/// exceed `0.2 * tokens.len() + 16` bits.
pub fn apply_lz77_backref(
    tokens: &[Token],
    num_contexts: usize,
    force_huffman: bool,
    distance_multiplier: i32,
) -> Option<(Vec<Token>, Lz77Params)> {
    if tokens.is_empty() {
        return None;
    }

    let mut lz77 = Lz77Params::new(num_contexts, force_huffman);

    let sce = SymbolCostEstimator::new(num_contexts, force_huffman, tokens, &lz77);

    // sym_cost[k] = estimated bits for coding tokens[..k] as literals, so any
    // span's literal cost is a difference of two entries.
    let mut sym_cost = vec![0.0f32; tokens.len() + 1];
    for (i, token) in tokens.iter().enumerate() {
        let e = UintCoder::encode(token.value);
        let cost = sce.symbol_cost(token.context as usize, e.token as usize) + e.nbits as f32;
        sym_cost[i + 1] = sym_cost[i] + cost;
    }

    let mut out = Vec::with_capacity(tokens.len());
    let mut bit_decrease: f32 = 0.0;
    let total_symbols = tokens.len();

    let max_distance = tokens.len();
    let min_length = lz77.min_length as usize;
    let max_length = tokens.len();

    // Smallest power of two covering the stream, capped at WINDOW_SIZE.
    let mut window_size = 1usize;
    while window_size < max_distance && window_size < WINDOW_SIZE {
        window_size <<= 1;
    }

    let mut chain = HashChain::new(
        tokens,
        window_size,
        min_length,
        max_length,
        distance_multiplier,
    );

    const MAX_LAZY_MATCH_LEN: usize = 256;
    // True when position i + 1 was already inserted during lazy matching.
    let mut already_updated = false;

    let mut i = 0usize;
    while i < tokens.len() {
        // Push the literal first; it is turned into the length token if a
        // match is accepted.
        out.push(tokens[i]);

        if !already_updated {
            chain.update(i);
        }
        already_updated = false;

        let (mut dist_symbol, mut len) = chain.find_match(i, max_distance);

        if len >= min_length {
            // Lazy matching: check whether starting one token later is better.
            if len < MAX_LAZY_MATCH_LEN && i + 1 < tokens.len() {
                chain.update(i + 1);
                already_updated = true;
                let (dist_symbol2, len2) = chain.find_match(i + 1, max_distance);
                if len2 > len {
                    // Keep tokens[i] as a literal and match from i + 1 instead;
                    // the new i is already in the chain, so the range update
                    // below starts right after it.
                    i += 1;
                    already_updated = false;
                    len = len2;
                    dist_symbol = dist_symbol2;
                    out.push(tokens[i]);
                }
            }

            let literal_cost = sym_cost[i + len] - sym_cost[i];
            let lz77_len = len - min_length;

            let lz77_cost = len_cost(lz77_len as u32)
                + dist_cost(dist_symbol as u32)
                + sce.add_symbol_cost(out.last().unwrap().context as usize);

            if lz77_cost <= literal_cost {
                // Turn the pushed literal into the length token, then add the
                // distance token.
                let last_token = out.last_mut().unwrap();
                last_token.value = lz77_len as u32;
                last_token.is_lz77_length = true;

                out.push(Token::new(lz77.distance_context, dist_symbol as u32));

                bit_decrease += literal_cost - lz77_cost;
            } else {
                // Match rejected: emit the remaining covered tokens as literals.
                for j in 1..len {
                    out.push(tokens[i + j]);
                }
            }

            // Insert the rest of the covered positions, skipping i + 1 if the
            // lazy check already inserted it.
            if already_updated {
                chain.update_range(i + 2, len - 2);
                already_updated = false;
            } else {
                chain.update_range(i + 1, len - 1);
            }
            i += len - 1;
        }
        i += 1;
    }

    let threshold = total_symbols as f32 * 0.2 + 16.0;
    #[cfg(feature = "debug-tokens")]
    eprintln!(
        "[LZ77-backref] bit_decrease={:.1}, threshold={:.1}, tokens: {} -> {}",
        bit_decrease,
        threshold,
        total_symbols,
        out.len()
    );
    if bit_decrease > threshold {
        lz77.enabled = true;
        Some((out, lz77))
    } else {
        None
    }
}
720
721pub fn apply_lz77_rle(
731 tokens: &[Token],
732 num_contexts: usize,
733 force_huffman: bool,
734) -> Option<(Vec<Token>, Lz77Params)> {
735 if tokens.is_empty() {
736 return None;
737 }
738
739 let mut lz77 = Lz77Params::new(num_contexts, force_huffman);
740
741 let sce = SymbolCostEstimator::new(num_contexts, force_huffman, tokens, &lz77);
744
745 let mut sym_cost = vec![0.0f32; tokens.len() + 1];
747 for (i, token) in tokens.iter().enumerate() {
748 let e = UintCoder::encode(token.value);
749 let cost = sce.symbol_cost(token.context as usize, e.token as usize) + e.nbits as f32;
750 sym_cost[i + 1] = sym_cost[i] + cost;
751 }
752
753 let mut out = Vec::with_capacity(tokens.len());
754 let mut bit_decrease: f32 = 0.0;
755 let total_symbols = tokens.len();
756
757 let mut i = 0;
758 while i < tokens.len() {
759 let mut num_to_copy = 0;
762 if i > 0 {
763 let prev_value = tokens[i - 1].value;
764 while i + num_to_copy < tokens.len() && tokens[i + num_to_copy].value == prev_value {
765 num_to_copy += 1;
766 }
767 }
768
769 if num_to_copy == 0 {
770 out.push(tokens[i]);
771 i += 1;
772 continue;
773 }
774
775 let literal_cost = sym_cost[i + num_to_copy] - sym_cost[i];
777
778 let lz77_cost = if num_to_copy >= lz77.min_length as usize {
780 let lz77_len = num_to_copy - lz77.min_length as usize;
781 ceil_log2_nonzero((lz77_len + 1) as u32) as f32 + 1.0
783 } else {
784 0.0
785 };
786
787 if num_to_copy < lz77.min_length as usize || literal_cost <= lz77_cost {
788 for j in 0..num_to_copy {
790 out.push(tokens[i + j]);
791 }
792 i += num_to_copy;
793 continue;
794 }
795
796 let lz77_len = (num_to_copy - lz77.min_length as usize) as u32;
798 out.push(Token::lz77_length(tokens[i].context, lz77_len));
799
800 out.push(Token::new(lz77.distance_context, 0));
803
804 bit_decrease += literal_cost - lz77_cost;
805 i += num_to_copy;
806 }
807
808 let threshold = total_symbols as f32 * 0.2 + 16.0;
810 #[cfg(feature = "debug-tokens")]
811 eprintln!(
812 "[LZ77] bit_decrease={:.1}, threshold={:.1}, tokens: {} -> {}",
813 bit_decrease,
814 threshold,
815 total_symbols,
816 out.len()
817 );
818 if bit_decrease > threshold {
819 lz77.enabled = true;
820 Some((out, lz77))
821 } else {
822 None
823 }
824}
825
/// Returns `ceil(log2(x))` for `x > 0`: `floor(log2(x))`, plus one unless `x`
/// is an exact power of two.
fn ceil_log2_nonzero(x: u32) -> u32 {
    debug_assert!(x > 0);
    let floor_log2 = 31 - x.leading_zeros();
    let round_up = u32::from(!x.is_power_of_two());
    floor_log2 + round_up
}
836
/// Selects which LZ77 strategy `apply_lz77` runs.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum Lz77Method {
    /// Only runs that repeat the previous token's value (`apply_lz77_rle`).
    /// This is the default.
    #[default]
    Rle,
    /// Hash-chain back-references with one-step lazy matching
    /// (`apply_lz77_backref`).
    Greedy,
}
849
850pub fn apply_lz77(
861 tokens: &[Token],
862 num_contexts: usize,
863 force_huffman: bool,
864 method: Lz77Method,
865 distance_multiplier: i32,
866) -> Option<(Vec<Token>, Lz77Params)> {
867 match method {
868 Lz77Method::Rle => apply_lz77_rle(tokens, num_contexts, force_huffman),
869 Lz77Method::Greedy => {
870 apply_lz77_backref(tokens, num_contexts, force_huffman, distance_multiplier)
871 }
872 }
873}
874
875#[allow(dead_code)] pub fn apply_lz77_best(
882 tokens: &[Token],
883 num_contexts: usize,
884 force_huffman: bool,
885 distance_multiplier: i32,
886) -> Option<(Vec<Token>, Lz77Params)> {
887 let rle_result = apply_lz77_rle(tokens, num_contexts, force_huffman);
888 let backref_result =
889 apply_lz77_backref(tokens, num_contexts, force_huffman, distance_multiplier);
890
891 match (&rle_result, &backref_result) {
892 (Some((rle_tokens, _)), Some((backref_tokens, _))) => {
893 if backref_tokens.len() <= rle_tokens.len() {
895 backref_result
896 } else {
897 rle_result
898 }
899 }
900 (Some(_), None) => rle_result,
901 (None, Some(_)) => backref_result,
902 (None, None) => None,
903 }
904}
905
/// Unit tests for the LZ77 helpers.
///
/// Several tests only assert inside `if let Some(..)`; a `None` result
/// (LZ77 judged not worthwhile) passes those tests silently.
#[cfg(test)]
mod tests {
    use super::*;

    /// Exact values around powers of two.
    #[test]
    fn test_ceil_log2_nonzero() {
        assert_eq!(ceil_log2_nonzero(1), 0);
        assert_eq!(ceil_log2_nonzero(2), 1);
        assert_eq!(ceil_log2_nonzero(3), 2);
        assert_eq!(ceil_log2_nonzero(4), 2);
        assert_eq!(ceil_log2_nonzero(5), 3);
        assert_eq!(ceil_log2_nonzero(8), 3);
        assert_eq!(ceil_log2_nonzero(9), 4);
    }

    /// A 3-token stream cannot exceed the `0.2 * n + 16` bit threshold.
    #[test]
    fn test_no_rle_on_short_stream() {
        let tokens = vec![Token::new(0, 5), Token::new(0, 5), Token::new(0, 5)];
        assert!(apply_lz77_rle(&tokens, 1, false).is_none());
    }

    /// One seed token followed by a run of 200 identical tokens.
    #[test]
    fn test_rle_on_long_run() {
        let mut tokens = Vec::new();
        tokens.push(Token::new(0, 5));
        for _ in 0..200 {
            tokens.push(Token::new(0, 5));
        }

        let result = apply_lz77_rle(&tokens, 1, false);
        // Checked only when RLE was enabled.
        if let Some((lz77_tokens, params)) = result {
            assert!(params.enabled);
            assert!(lz77_tokens.len() < tokens.len());
            assert!(lz77_tokens.iter().any(|t| t.is_lz77_length));
        }
    }

    /// Distinct prefix, long run of 42, distinct suffix: the first token must
    /// remain a plain literal.
    #[test]
    fn test_rle_preserves_non_runs() {
        let mut tokens = Vec::new();
        for i in 0..10 {
            tokens.push(Token::new(0, i));
        }
        for _ in 0..100 {
            tokens.push(Token::new(0, 42));
        }
        for i in 0..10 {
            tokens.push(Token::new(0, i + 100));
        }

        if let Some((lz77_tokens, params)) = apply_lz77_rle(&tokens, 1, false) {
            assert!(params.enabled);
            assert!(lz77_tokens.len() < tokens.len());
            assert_eq!(lz77_tokens[0].value, 0);
            assert!(!lz77_tokens[0].is_lz77_length);
        }
    }

    #[test]
    fn test_empty_stream() {
        assert!(apply_lz77_rle(&[], 1, false).is_none());
    }

    #[test]
    fn test_backref_empty_stream() {
        assert!(apply_lz77_backref(&[], 1, false, 0).is_none());
    }

    /// Too short to exceed the enabling threshold.
    #[test]
    fn test_backref_short_stream() {
        let tokens = vec![Token::new(0, 5), Token::new(0, 5), Token::new(0, 5)];
        assert!(apply_lz77_backref(&tokens, 1, false, 0).is_none());
    }

    /// Period-3 pattern repeated 100 times.
    #[test]
    fn test_backref_on_repeating_pattern() {
        let mut tokens = Vec::new();
        for _ in 0..100 {
            tokens.push(Token::new(0, 10));
            tokens.push(Token::new(0, 20));
            tokens.push(Token::new(0, 30));
        }

        let result = apply_lz77_backref(&tokens, 1, false, 0);
        if let Some((lz77_tokens, params)) = result {
            assert!(params.enabled);
            assert!(
                lz77_tokens.len() < tokens.len(),
                "backref should compress pattern: {} vs {}",
                lz77_tokens.len(),
                tokens.len()
            );
            assert!(lz77_tokens.iter().any(|t| t.is_lz77_length));
        }
    }

    /// Period-5 pattern with no adjacent repeats, which RLE cannot exploit.
    #[test]
    fn test_backref_finds_longer_matches_than_rle() {
        let mut tokens = Vec::new();
        for _ in 0..50 {
            for j in 1..=5 {
                tokens.push(Token::new(0, j));
            }
        }

        let rle_result = apply_lz77_rle(&tokens, 1, false);
        let backref_result = apply_lz77_backref(&tokens, 1, false, 0);

        match (&rle_result, &backref_result) {
            (None, Some((backref_tokens, _))) => {
                assert!(backref_tokens.len() < tokens.len());
            }
            (Some((rle_tokens, _)), Some((backref_tokens, _))) => {
                assert!(backref_tokens.len() <= rle_tokens.len());
            }
            _ => {}
        }
    }

    /// Rows of width 64 that repeat, so the row-above special distance applies.
    #[test]
    fn test_backref_with_distance_multiplier() {
        let mut tokens = Vec::new();
        let image_width = 64;

        for _row in 0..20 {
            for col in 0..image_width {
                tokens.push(Token::new(0, (col % 16) as u32));
            }
        }

        let _result_no_mult = apply_lz77_backref(&tokens, 1, false, 0);
        let result_with_mult = apply_lz77_backref(&tokens, 1, false, image_width);

        if let Some((tokens_mult, params)) = result_with_mult {
            assert!(params.enabled);
            assert!(tokens_mult.len() < tokens.len());
        }
    }

    /// Entries 0..=3 are [0, 1], [1, 0], [1, 1], [-1, 1]; with multiplier 64
    /// they give 64, 1, 65 and 63.
    #[test]
    fn test_special_distance() {
        assert_eq!(special_distance(0, 64), 64);
        assert_eq!(special_distance(1, 64), 1);
        assert_eq!(special_distance(2, 64), 65);
        assert_eq!(special_distance(3, 64), 63);
    }

    /// Sanity bounds only; exact costs are not pinned.
    #[test]
    fn test_len_cost() {
        for len in 0..1000 {
            let cost = len_cost(len);
            assert!(cost >= 0.0, "len_cost({}) should be non-negative", len);
            assert!(cost < 100.0, "len_cost({}) should be reasonable", len);
        }
    }

    /// Sanity bounds only; exact costs are not pinned.
    #[test]
    fn test_dist_cost() {
        for dist in 0..10000 {
            let cost = dist_cost(dist);
            assert!(cost >= 0.0, "dist_cost({}) should be non-negative", dist);
            assert!(cost < 100.0, "dist_cost({}) should be reasonable", dist);
        }
    }

    /// Both enum variants dispatch and, if they return a result, mark it enabled.
    #[test]
    fn test_apply_lz77_method_enum() {
        let mut tokens = Vec::new();
        tokens.push(Token::new(0, 5));
        for _ in 0..200 {
            tokens.push(Token::new(0, 5));
        }

        let rle_result = apply_lz77(&tokens, 1, false, Lz77Method::Rle, 0);
        if let Some((_, params)) = &rle_result {
            assert!(params.enabled);
        }

        let greedy_result = apply_lz77(&tokens, 1, false, Lz77Method::Greedy, 0);
        if let Some((_, params)) = &greedy_result {
            assert!(params.enabled);
        }
    }

    /// Period-10 pattern; the chosen result (if any) must be smaller than the input.
    #[test]
    fn test_apply_lz77_best() {
        let mut tokens = Vec::new();
        for _ in 0..50 {
            for j in 1..=10 {
                tokens.push(Token::new(0, j));
            }
        }

        let best_result = apply_lz77_best(&tokens, 1, false, 0);
        if let Some((best_tokens, params)) = best_result {
            assert!(params.enabled);
            assert!(best_tokens.len() < tokens.len());
        }
    }

    /// Positions 4..7 repeat positions 0..3 at distance 4. With multiplier 0
    /// there are no special distances, so the symbol is dist - 1 = 3.
    #[test]
    fn test_hash_chain_basic() {
        let tokens = vec![
            Token::new(0, 10),
            Token::new(0, 20),
            Token::new(0, 30),
            Token::new(0, 40),
            Token::new(0, 10),
            Token::new(0, 20),
            Token::new(0, 30),
        ];

        let mut chain = HashChain::new(&tokens, 16, 3, 100, 0);
        for i in 0..tokens.len() {
            chain.update(i);
        }

        let (dist_symbol, len) = chain.find_match(4, 10);
        assert!(len >= 3, "should find match of length >= 3, got {}", len);
        assert_eq!(dist_symbol, 3, "distance symbol for dist=4 should be 3");
    }
}