Skip to main content

jxl_encoder/entropy_coding/
lz77.rs

1// Copyright (c) Imazen LLC and the JPEG XL Project Authors.
2// Algorithms and constants derived from libjxl (BSD-3-Clause).
3// Licensed under AGPL-3.0-or-later. Commercial licenses at https://www.imazen.io/pricing
4
5//! LZ77 backward references for entropy coding.
6//!
7//! Implements both RLE-only and full backward-reference LZ77 methods from libjxl.
8//! - `apply_lz77_rle`: RLE-only (consecutive identical values)
9//! - `apply_lz77_backref`: Full backward references with hash chains (greedy matching)
10//!
11//! The backward-reference method uses hash chains to find matches at arbitrary
12//! distances within a sliding window, providing 1-3% compression improvement on
13//! photographic content compared to RLE-only.
14
15use hashbrown::HashMap;
16
17use super::token::{Lz77UintCoder, Token, UintCoder};
18use crate::bit_writer::BitWriter;
19use crate::error::Result;
20
/// Upper bound on the LZ77 sliding window: 2^20 = 1,048,576 token positions
/// (1M, the same limit libjxl uses).
const WINDOW_SIZE: usize = 1_048_576;

/// How many special (2D-pattern) distance codes the WebP-lossless table defines.
const NUM_SPECIAL_DISTANCES: usize = 120;
26
/// Special distance codes from WebP lossless.
/// Each entry is [offset, multiplier] where distance = offset + multiplier * image_width.
/// These encode common 2D patterns (horizontal, vertical, diagonal) compactly.
///
/// The index of an entry is the distance symbol emitted when a match lands at
/// that distance. `image_width` here is the `distance_multiplier` passed to the
/// LZ77 pass (presumably the row width of the image — confirm at call sites).
/// Entries whose resulting distance is not positive are never used.
#[rustfmt::skip]
const SPECIAL_DISTANCES: [[i8; 2]; NUM_SPECIAL_DISTANCES] = [
    [0, 1],  [1, 0],  [1, 1],  [-1, 1], [0, 2],  [2, 0],  [1, 2],  [-1, 2],
    [2, 1],  [-2, 1], [2, 2],  [-2, 2], [0, 3],  [3, 0],  [1, 3],  [-1, 3],
    [3, 1],  [-3, 1], [2, 3],  [-2, 3], [3, 2],  [-3, 2], [0, 4],  [4, 0],
    [1, 4],  [-1, 4], [4, 1],  [-4, 1], [3, 3],  [-3, 3], [2, 4],  [-2, 4],
    [4, 2],  [-4, 2], [0, 5],  [3, 4],  [-3, 4], [4, 3],  [-4, 3], [5, 0],
    [1, 5],  [-1, 5], [5, 1],  [-5, 1], [2, 5],  [-2, 5], [5, 2],  [-5, 2],
    [4, 4],  [-4, 4], [3, 5],  [-3, 5], [5, 3],  [-5, 3], [0, 6],  [6, 0],
    [1, 6],  [-1, 6], [6, 1],  [-6, 1], [2, 6],  [-2, 6], [6, 2],  [-6, 2],
    [4, 5],  [-4, 5], [5, 4],  [-5, 4], [3, 6],  [-3, 6], [6, 3],  [-6, 3],
    [0, 7],  [7, 0],  [1, 7],  [-1, 7], [5, 5],  [-5, 5], [7, 1],  [-7, 1],
    [4, 6],  [-4, 6], [6, 4],  [-6, 4], [2, 7],  [-2, 7], [7, 2],  [-7, 2],
    [3, 7],  [-3, 7], [7, 3],  [-7, 3], [5, 6],  [-5, 6], [6, 5],  [-6, 5],
    [8, 0],  [4, 7],  [-4, 7], [7, 4],  [-7, 4], [8, 1],  [8, 2],  [6, 6],
    [-6, 6], [8, 3],  [5, 7],  [-5, 7], [7, 5],  [-7, 5], [8, 4],  [6, 7],
    [-6, 7], [7, 6],  [-7, 6], [8, 5],  [7, 7],  [-7, 7], [8, 6],  [8, 7],
];
48
49/// Compute special distance from code index and distance multiplier (image width).
50#[inline]
51fn special_distance(index: usize, multiplier: i32) -> i32 {
52    SPECIAL_DISTANCES[index][0] as i32 + multiplier * SPECIAL_DISTANCES[index][1] as i32
53}
54
/// Empirical cost table for LZ77 length encoding (from libjxl).
/// Indexed by token value from HybridUintConfig(1, 0, 0).
///
/// Values are base costs in bits for each token. `len_cost` clamps tokens past
/// the last entry to entry 16 and adds the token's raw extra bits on top.
#[rustfmt::skip]
#[allow(clippy::excessive_precision)]
const LEN_COST_TABLE: [f32; 17] = [
    2.797667318563126,  3.213177690381199,  2.5706009246743737,
    2.408392498667534,  2.829649191872326,  3.3923087753324577,
    4.029267451554331,  4.415576699706408,  4.509357574741465,
    9.21481543803004,   10.020590190114898, 11.858671627804766,
    12.45853300490526,  11.713105831990857, 12.561996324849314,
    13.775477692278367, 13.174027068768641,
];
67
/// Empirical cost table for LZ77 distance encoding (from libjxl).
/// Indexed by token value from HybridUintConfig(7, 0, 0).
///
/// Values are base costs in bits per token. `dist_cost` clamps tokens past the
/// last entry to entry 138 and adds the token's raw extra bits on top.
#[rustfmt::skip]
#[allow(clippy::excessive_precision)]
const DIST_COST_TABLE: [f32; 139] = [
    6.368282626312716,  5.680793277090298,  8.347404197105247,
    7.641619201599141,  6.914328374119438,  7.959808291537444,
    8.70023120759855,   8.71378518934703,   9.379132523982769,
    9.110472749092708,  9.159029569270908,  9.430936766731973,
    7.278284055315169,  7.8278514904267755, 10.026641158289236,
    9.976049229827066,  9.64351607048908,   9.563403863480442,
    10.171474111762747, 10.45950155077234,  9.994813912104219,
    10.322524683741156, 8.465808729388186,  8.756254166066853,
    10.160930174662234, 10.247329273413435, 10.04090403724809,
    10.129398517544082, 9.342311691539546,  9.07608009102374,
    10.104799540677513, 10.378079384990906, 10.165828974075072,
    10.337595322341553, 7.940557464567944,  10.575665823319431,
    11.023344321751955, 10.736144698831827, 11.118277044595054,
    7.468468230648442,  10.738305230932939, 10.906980780216568,
    10.163468216353817, 10.17805759656433,  11.167283670483565,
    11.147050200274544, 10.517921919244333, 10.651764778156886,
    10.17074446448919,  11.217636876224745, 11.261630721139484,
    11.403140815247259, 10.892472096873417, 11.1859607804481,
    8.017346947551262,  7.895143720278828,  11.036577113822025,
    11.170562110315794, 10.326988722591086, 10.40872184751056,
    11.213498225466386, 11.30580635516863,  10.672272515665442,
    10.768069466228063, 11.145257364153565, 11.64668307145549,
    10.593156194627339, 11.207499484844943, 10.767517766396908,
    10.826629811407042, 10.737764794499988, 10.6200448518045,
    10.191315385198092, 8.468384171390085,  11.731295299170432,
    11.824619886654398, 10.41518844301179,  10.16310536548649,
    10.539423685097576, 10.495136599328031, 10.469112847728267,
    11.72057686174922,  10.910326337834674, 11.378921834673758,
    11.847759036098536, 11.92071647623854,  10.810628276345282,
    11.008601085273893, 11.910326337834674, 11.949212023423133,
    11.298614839104337, 11.611603659010392, 10.472930394619985,
    11.835564720850282, 11.523267392285337, 12.01055816679611,
    8.413029688994023,  11.895784139536406, 11.984679534970505,
    11.220654278717394, 11.716311684833672, 10.61036646226114,
    10.89849965960364,  10.203762898863669, 10.997560826267238,
    11.484217379438984, 11.792836176993665, 12.24310468755171,
    11.464858097919262, 12.212747017409377, 11.425595666074955,
    11.572048533398757, 12.742093965163013, 11.381874288645637,
    12.191870445817015, 11.683156920035426, 11.152442115262197,
    11.90303691580457,  11.653292787169159, 11.938615382266098,
    16.970641701570223, 16.853602280380002, 17.26240782594733,
    16.644655390108507, 17.14310889757499,  16.910935455445955,
    17.505678976959697, 17.213498225466388,
    // Entries 128-138 (from libjxl enc_lz77.cc).
    // NOTE(review): these are not special-distance costs. Special distance
    // symbols are 0-119, which HybridUintConfig(7, 0, 0) passes through as
    // tokens 0-119. Under that config the split token is 1 << 7 = 128, so
    // entries 128-138 price the floor_log2 buckets n = 7..17 of distance
    // symbols >= 128. Their n raw extra bits are added separately by the
    // caller. Confirm against libjxl.
    2.4162310293553024, 3.494587244462329,  3.5258600986408344,
    3.4959806589517095, 3.098390886949687,  3.343454654302911,
    3.588847442290287,  4.14614790111827,   5.152948641990529,
    7.433696808092598,  9.716311684833672,
];
125
126/// Empirical cost for LZ77 length encoding.
127fn len_cost(len: u32) -> f32 {
128    // HybridUintConfig(1, 0, 0): token = 1 + floor_log2(len) for len >= 1
129    let (tok, nbits) = if len == 0 {
130        (0u32, 0u32)
131    } else {
132        let n = 31 - len.leading_zeros();
133        (1 + n, n)
134    };
135    let table_size = LEN_COST_TABLE.len();
136    let tok_idx = (tok as usize).min(table_size - 1);
137    LEN_COST_TABLE[tok_idx] + nbits as f32
138}
139
140/// Empirical cost for LZ77 distance encoding.
141fn dist_cost(dist: u32) -> f32 {
142    // HybridUintConfig(7, 0, 0): different split point
143    let (tok, nbits) = hybrid_uint_encode_7_0_0(dist);
144    let table_size = DIST_COST_TABLE.len();
145    let tok_idx = (tok as usize).min(table_size - 1);
146    DIST_COST_TABLE[tok_idx] + nbits as f32
147}
148
/// HybridUint encoding with config (7, 0, 0) for distance symbols.
///
/// Returns `(token, nbits)`. With `split_exponent = 7` the split token is
/// `1 << 7 = 128`:
/// - values below 128 are their own token and carry no extra bits;
/// - a value `v >= 128` with `n = floor(log2(v))` maps to token `128 + (n - 7)`
///   and carries `n` raw bits.
///
/// `msb_in_token = lsb_in_token = 0`, so no mantissa bits are folded into the
/// token. For 32-bit input the token never exceeds `128 + 24 = 152`.
fn hybrid_uint_encode_7_0_0(value: u32) -> (u32, u32) {
    const SPLIT_EXPONENT: u32 = 7;
    // The split is 1 << split_exponent, not split_exponent itself.
    const SPLIT_TOKEN: u32 = 1 << SPLIT_EXPONENT;

    if value < SPLIT_TOKEN {
        (value, 0)
    } else {
        // value >= 128, so n >= 7 and the subtraction cannot underflow.
        let n = 31 - value.leading_zeros();
        (SPLIT_TOKEN + n - SPLIT_EXPONENT, n)
    }
}
162
/// LZ77 parameters serialized in the entropy code header.
#[derive(Debug, Clone)]
pub struct Lz77Params {
    /// Whether LZ77 is signalled as enabled (starts out `false`).
    pub enabled: bool,
    /// Symbols >= min_symbol are LZ77 length tokens.
    /// ANS: 224, Huffman: 512.
    pub min_symbol: u32,
    /// Minimum run length to encode as LZ77. Default: 3.
    pub min_length: u32,
    /// Context index for distance tokens (= num_contexts before LZ77).
    pub distance_context: u32,
}

impl Lz77Params {
    /// Creates disabled LZ77 parameters for a stream with `num_contexts` contexts.
    ///
    /// The distance context is appended after the existing contexts. The
    /// length-token offset depends on whether Huffman coding is forced.
    pub fn new(num_contexts: usize, force_huffman: bool) -> Self {
        const ANS_MIN_SYMBOL: u32 = 224;
        const HUFFMAN_MIN_SYMBOL: u32 = 512;

        let min_symbol = match force_huffman {
            true => HUFFMAN_MIN_SYMBOL,
            false => ANS_MIN_SYMBOL,
        };
        let distance_context = num_contexts as u32;

        Self {
            distance_context,
            min_length: 3,
            min_symbol,
            enabled: false,
        }
    }
}
186
187/// Write LZ77 header to the bitstream.
188///
189/// If `lz77` is `Some`, writes `enabled=1` followed by min_symbol, min_length,
190/// and length_uint_config. If `None`, writes `enabled=0`.
191///
192/// JXL spec format:
193/// ```text
194/// Bool(enabled)
195/// if enabled:
196///   U32(Val(224), Val(512), Val(4096), BitsOffset(15,8))  // min_symbol
197///   U32(Val(3), Val(4), BitsOffset(2,5), BitsOffset(8,9)) // min_length
198///   EncodeUintConfig(length_uint_config, log_alpha_size=8)
199/// ```
200pub fn write_lz77_header(lz77: Option<&Lz77Params>, writer: &mut BitWriter) -> Result<()> {
201    if let Some(params) = lz77 {
202        writer.write(1, 1)?; // lz77 enabled
203
204        // min_symbol: U32(Val(224), Val(512), Val(4096), BitsOffset(15,8))
205        match params.min_symbol {
206            224 => writer.write(2, 0)?,  // selector 0 = Val(224)
207            512 => writer.write(2, 1)?,  // selector 1 = Val(512)
208            4096 => writer.write(2, 2)?, // selector 2 = Val(4096)
209            v => {
210                writer.write(2, 3)?; // selector 3 = BitsOffset(15, 8)
211                writer.write(15, (v - 8) as u64)?;
212            }
213        }
214
215        // min_length: U32(Val(3), Val(4), BitsOffset(2,5), BitsOffset(8,9))
216        match params.min_length {
217            3 => writer.write(2, 0)?, // selector 0 = Val(3)
218            4 => writer.write(2, 1)?, // selector 1 = Val(4)
219            v @ 5..=8 => {
220                writer.write(2, 2)?; // selector 2 = BitsOffset(2, 5)
221                writer.write(2, (v - 5) as u64)?;
222            }
223            v => {
224                writer.write(2, 3)?; // selector 3 = BitsOffset(8, 9)
225                writer.write(8, (v - 9) as u64)?;
226            }
227        }
228
229        // length_uint_config: HybridUintConfig(0, 0, 0)
230        // split_exponent=0 → 4 bits, msb/lsb need 0 bits each
231        writer.write(4, 0)?;
232    } else {
233        writer.write(1, 0)?; // no lz77
234    }
235    Ok(())
236}
237
238/// Estimate per-symbol bit cost from histograms, matching libjxl's SymbolCostEstimator.
239struct SymbolCostEstimator {
240    /// Flat array: bits[ctx * max_alphabet_size + sym]
241    bits: Vec<f32>,
242    max_alphabet_size: usize,
243}
244
245impl SymbolCostEstimator {
246    fn new(num_contexts: usize, force_huffman: bool, tokens: &[Token], lz77: &Lz77Params) -> Self {
247        const ANS_LOG_TAB_SIZE: f32 = 12.0;
248
249        // Build per-context histograms from the (possibly LZ77-transformed) tokens.
250        let mut counts: Vec<Vec<u32>> = vec![vec![]; num_contexts];
251        let mut total_counts = vec![0u32; num_contexts];
252
253        for token in tokens {
254            let (tok, _nbits) = if token.is_lz77_length() {
255                let e = Lz77UintCoder::encode(token.value);
256                (e.token + lz77.min_symbol, e.nbits)
257            } else {
258                let e = UintCoder::encode(token.value);
259                (e.token, e.nbits)
260            };
261            let ctx = token.context() as usize;
262            if ctx < num_contexts {
263                let sym = tok as usize;
264                if sym >= counts[ctx].len() {
265                    counts[ctx].resize(sym + 1, 0);
266                }
267                counts[ctx][sym] += 1;
268                total_counts[ctx] += 1;
269            }
270        }
271
272        let max_alphabet_size = counts.iter().map(|c| c.len()).max().unwrap_or(0);
273        let mut bits = vec![0.0f32; num_contexts * max_alphabet_size];
274
275        for ctx in 0..num_contexts {
276            let total = total_counts[ctx];
277            if total == 0 {
278                continue;
279            }
280            let inv_total = 1.0 / (total as f32 + 1e-8);
281            for sym in 0..counts[ctx].len() {
282                let cnt = counts[ctx][sym];
283                let cost = if cnt != 0 && cnt != total {
284                    let p = cnt as f32 * inv_total;
285                    let c = -jxl_simd::fast_log2f(p);
286                    if force_huffman { c.ceil() } else { c }
287                } else if cnt == 0 {
288                    ANS_LOG_TAB_SIZE // Highest possible cost
289                } else {
290                    0.0 // Single symbol, zero cost
291                };
292                bits[ctx * max_alphabet_size + sym] = cost;
293            }
294        }
295
296        Self {
297            bits,
298            max_alphabet_size,
299        }
300    }
301
302    #[inline]
303    fn symbol_cost(&self, ctx: usize, sym: usize) -> f32 {
304        if sym < self.max_alphabet_size {
305            self.bits[ctx * self.max_alphabet_size + sym]
306        } else {
307            12.0 // ANS_LOG_TAB_SIZE as fallback
308        }
309    }
310
311    /// Cost of adding an LZ77 symbol to a context (penalty for low-entropy contexts).
312    fn add_symbol_cost(&self, ctx: usize) -> f32 {
313        // Compute average cost per symbol in this context
314        let mut total_cost = 0.0f32;
315        let mut total_count = 0u32;
316        for sym in 0..self.max_alphabet_size {
317            let cost = self.bits[ctx * self.max_alphabet_size + sym];
318            if cost < 12.0 {
319                // Only count symbols that exist in the histogram
320                total_cost += cost;
321                total_count += 1;
322            }
323        }
324        if total_count == 0 {
325            return 0.0;
326        }
327        // Higher penalty for contexts with low per-symbol entropy
328        (6.0 - total_cost / total_count as f32).max(0.0)
329    }
330
331    /// Cost of encoding an LZ77 length token using histogram-based estimation.
332    fn len_cost(&self, ctx: usize, len: u32, lz77: &Lz77Params) -> f32 {
333        // HybridUintConfig(1, 0, 0) for LZ77 length
334        let (tok, nbits) = if len == 0 {
335            (0u32, 0u32)
336        } else {
337            let n = 31 - len.leading_zeros();
338            (1 + n, n)
339        };
340        let sym = tok + lz77.min_symbol;
341        nbits as f32 + self.symbol_cost(ctx, sym as usize)
342    }
343
344    /// Cost of encoding an LZ77 distance token using histogram-based estimation.
345    fn dist_cost_sce(&self, dist_symbol: u32, lz77: &Lz77Params) -> f32 {
346        let (tok, nbits) = UintCoder::encode(dist_symbol).into();
347        nbits as f32 + self.symbol_cost(lz77.distance_context as usize, tok as usize)
348    }
349}
350
/// Hash chain for LZ77 match finding.
///
/// Uses a sliding window and hash table to efficiently find matching sequences.
/// Matches libjxl's HashChain implementation in enc_lz77.cc.
///
/// Positions are stored modulo `window_size` ("window positions").
/// - `chain` links a window position to the previous position whose next three
///   values hashed to the same bucket.
/// - `chainz` links positions that start with the same number of consecutive
///   zero values.
/// - A link that points at itself terminates traversal.
struct HashChain {
    /// Token values (we only hash on value, not context)
    data: Vec<u32>,
    /// Size of token stream
    size: usize,
    /// Window size (power of 2)
    window_size: usize,
    /// Window mask (window_size - 1)
    window_mask: usize,
    /// Minimum match length
    min_length: usize,
    /// Maximum match length
    max_length: usize,

    // Hash table parameters
    #[allow(dead_code)] // Stored for debugging/reference
    hash_num_values: usize,
    hash_mask: usize,
    hash_shift: u32,

    /// Head of hash chain for each hash value (-1 if empty)
    head: Vec<i32>,
    /// Hash chain: next position with same hash
    chain: Vec<u32>,
    /// Hash value at each window position (-1 if invalid)
    val: Vec<i32>,

    // Zero-run optimization
    /// Head of zero-run chain for each run length
    headz: Vec<i32>,
    /// Zero-run chain
    chainz: Vec<u32>,
    /// Number of consecutive zeros starting at each position
    zeros: Vec<u32>,
    /// Current zero count (for the most recently updated position)
    numzeros: u32,

    /// Map from actual distance to special distance symbol
    special_dist_table: HashMap<i32, usize>,
    /// Number of special distances (0 if no multiplier, 120 otherwise)
    num_special_distances: usize,

    /// Maximum chain length to traverse (limits search time)
    max_chain_length: u32,
}

impl HashChain {
    /// Creates an empty chain over the values of `tokens`.
    ///
    /// `window_size` must be a power of two, because `window_size - 1` is used
    /// as a mask. When `distance_multiplier` is non-zero, positive special
    /// distances are registered so that matches at those distances receive
    /// symbols below `NUM_SPECIAL_DISTANCES`.
    fn new(
        tokens: &[Token],
        window_size: usize,
        min_length: usize,
        max_length: usize,
        distance_multiplier: i32,
    ) -> Self {
        let size = tokens.len();

        // Extract just the values
        let data: Vec<u32> = tokens.iter().map(|t| t.value).collect();

        // Hash table setup
        let hash_num_values = 32768usize;
        let hash_mask = hash_num_values - 1;
        let hash_shift = 5u32;

        let head = vec![-1i32; hash_num_values];
        let chain: Vec<u32> = (0..window_size as u32).collect(); // Self-reference indicates uninitialized
        let val = vec![-1i32; window_size];

        // Zero-run optimization
        // (headz is indexed by zero count, which can reach window_size, hence the +1)
        let headz = vec![-1i32; window_size + 1];
        let chainz: Vec<u32> = (0..window_size as u32).collect();
        let zeros = vec![0u32; window_size];

        // Build special distance table
        let mut special_dist_table = HashMap::new();
        let num_special_distances = if distance_multiplier != 0 {
            // Count down so smallest code wins on ties
            for i in (0..NUM_SPECIAL_DISTANCES).rev() {
                let dist = special_distance(i, distance_multiplier);
                if dist > 0 {
                    special_dist_table.insert(dist, i);
                }
            }
            NUM_SPECIAL_DISTANCES
        } else {
            0
        };

        Self {
            data,
            size,
            window_size,
            window_mask: window_size - 1,
            min_length,
            max_length,
            hash_num_values,
            hash_mask,
            hash_shift,
            head,
            chain,
            val,
            headz,
            chainz,
            zeros,
            numzeros: 0,
            special_dist_table,
            num_special_distances,
            max_chain_length: 256,
        }
    }

    /// Compute hash of 3 consecutive values starting at pos.
    ///
    /// Only the low 16 bits of each value contribute. Returns 0 for the last
    /// two positions, where fewer than three values remain.
    fn get_hash(&self, pos: usize) -> u32 {
        if pos + 2 >= self.size {
            return 0;
        }
        let mut result = 0u32;
        result ^= self.data[pos] & 0xFFFF;
        result ^= (self.data[pos + 1] & 0xFFFF) << self.hash_shift;
        result ^= (self.data[pos + 2] & 0xFFFF) << (self.hash_shift * 2);
        result & self.hash_mask as u32
    }

    /// Count consecutive zeros starting at pos.
    ///
    /// The scan is capped at `window_size` values. When `prev_zeros` (the
    /// count at `pos - 1`) is non-zero, the run is derived incrementally
    /// instead of rescanned. It usually shrinks by one; it stays the same when
    /// saturated at the window limit and the zero run continues past the window.
    fn count_zeros(&self, pos: usize, prev_zeros: u32) -> u32 {
        let end = (pos + self.window_size).min(self.size);
        if prev_zeros > 0 {
            if prev_zeros >= self.window_mask as u32
                && self.data[end - 1] == 0
                && end == pos + self.window_size
            {
                return prev_zeros;
            } else {
                return prev_zeros - 1;
            }
        }
        let mut num = 0u32;
        while pos + (num as usize) < end && self.data[pos + (num as usize)] == 0 {
            num += 1;
        }
        num
    }

    /// Update hash chain with position pos.
    ///
    /// Inserts `pos` at the head of its hash bucket and of its zero-run bucket,
    /// and records its hash and zero count. `numzeros` is carried over from the
    /// previous call, so positions appear to need updating in increasing order
    /// without gaps.
    fn update(&mut self, pos: usize) {
        let hashval = self.get_hash(pos);
        let wpos = pos & self.window_mask;

        self.val[wpos] = hashval as i32;
        if self.head[hashval as usize] != -1 {
            self.chain[wpos] = self.head[hashval as usize] as u32;
        }
        self.head[hashval as usize] = wpos as i32;

        // Update zero count
        // (a change of value ends the previous run, forcing a fresh scan)
        if pos > 0 && self.data[pos] != self.data[pos - 1] {
            self.numzeros = 0;
        }
        self.numzeros = self.count_zeros(pos, self.numzeros);

        self.zeros[wpos] = self.numzeros;
        if self.headz[self.numzeros as usize] != -1 {
            self.chainz[wpos] = self.headz[self.numzeros as usize] as u32;
        }
        self.headz[self.numzeros as usize] = wpos as i32;
    }

    /// Update hash chain for multiple positions.
    ///
    /// Calls `update` for `pos..pos + len`, in order.
    fn update_range(&mut self, pos: usize, len: usize) {
        for i in 0..len {
            self.update(pos + i);
        }
    }

    /// Find best match at position pos.
    /// Returns (distance_symbol, match_length).
    ///
    /// Prefers the longest match and, among equal lengths, the smallest distance
    /// symbol. Returns `(0, 1)` when no match of at least `min_length` is reported.
    fn find_match(&self, pos: usize, max_dist: usize) -> (usize, usize) {
        let mut best_dist_symbol = 0usize;
        let mut best_len = 1usize;

        self.find_matches(pos, max_dist, |len, dist_symbol| {
            if len > best_len || (len == best_len && dist_symbol < best_dist_symbol) {
                best_len = len;
                best_dist_symbol = dist_symbol;
            }
        });

        (best_dist_symbol, best_len)
    }

    /// Find all matches at position pos, calling callback for each.
    ///
    /// The callback receives `(length, distance_symbol)`. A match is reported
    /// when its length is at least `min_length` and within 2 of the best length
    /// seen so far; a slightly shorter match may still have a cheaper distance
    /// symbol.
    /// - Distances found in the special-distance table map to their table index.
    /// - Other distances map to `num_special_distances + dist - 1`.
    ///
    /// Traversal stops after `max_chain_length` candidates, at a chain end, or
    /// when the chain leads to a stale entry.
    fn find_matches<F>(&self, pos: usize, max_dist: usize, mut found_match: F)
    where
        F: FnMut(usize, usize),
    {
        let wpos = pos & self.window_mask;
        let hashval = self.get_hash(pos);
        let mut hashpos = self.chain[wpos];

        let mut prev_dist = 0i32;
        let end = (pos + self.max_length).min(self.size);
        let mut chain_length = 0u32;
        let mut best_len = 0usize;

        loop {
            // Compute distance from current position to hash chain position
            // (window positions wrap, so add the window size when hashpos is "ahead")
            let dist = if hashpos as usize <= wpos {
                wpos - hashpos as usize
            } else {
                wpos + self.window_mask + 1 - hashpos as usize
            };

            // Distances should grow along the chain; a decrease presumably means
            // the chain has run into overwritten (stale) window slots.
            if (dist as i32) < prev_dist {
                break;
            }
            prev_dist = dist as i32;

            if dist > 0 && dist <= max_dist {
                // Compare sequences
                let mut i = pos;
                let mut j = pos - dist;

                // Zero-run optimization: skip known zeros
                // (both sides start with at least r zeros, capped so i stays < end)
                if self.numzeros > 3 {
                    let r =
                        ((self.numzeros - 1) as usize).min(self.zeros[hashpos as usize] as usize);
                    let skip = if i + r >= end { end - i - 1 } else { r };
                    i += skip;
                    j += skip;
                }

                // Extend match
                while i < end && self.data[i] == self.data[j] {
                    i += 1;
                    j += 1;
                }

                let len = i - pos;

                // Accept match if long enough and potentially better
                if len >= self.min_length && len + 2 >= best_len {
                    let dist_symbol =
                        if let Some(&sym) = self.special_dist_table.get(&(dist as i32)) {
                            sym
                        } else {
                            self.num_special_distances + dist - 1
                        };
                    found_match(len, dist_symbol);
                    if len > best_len {
                        best_len = len;
                    }
                }
            }

            chain_length += 1;
            if chain_length >= self.max_chain_length {
                break;
            }

            // Follow chain
            // NOTE(review): libjxl's FindMatches appears to test the length of the
            // current candidate here rather than the best length found so far;
            // confirm whether using `best_len` is an intentional divergence.
            if self.numzeros >= 3 && best_len > self.numzeros as usize {
                // Use zero-run chain for efficiency
                if hashpos == self.chainz[hashpos as usize] {
                    break;
                }
                hashpos = self.chainz[hashpos as usize];
                if self.zeros[hashpos as usize] != self.numzeros {
                    break;
                }
            } else {
                // Use regular hash chain
                if hashpos == self.chain[hashpos as usize] {
                    break;
                }
                hashpos = self.chain[hashpos as usize];
                if self.val[hashpos as usize] != hashval as i32 {
                    // Outdated hash value
                    break;
                }
            }
        }
    }
}
638
/// Apply greedy LZ77 with backward references using hash chains.
///
/// This implements libjxl's `ApplyLZ77_LZ77` algorithm which uses hash chains
/// to find matching sequences at arbitrary distances within a sliding window.
/// Includes lazy matching to find longer matches at the next position.
///
/// A match is emitted only when its estimated cost is no greater than the
/// estimated cost of coding the covered tokens literally. The match cost is the
/// empirical length cost plus the empirical distance cost plus a per-context
/// penalty.
///
/// An emitted match has two parts:
/// - the token at the match start becomes an LZ77 length token with value
///   `len - min_length`;
/// - a distance token with context `distance_context` and value equal to the
///   distance symbol is appended.
///
/// Returns `Some((transformed_tokens, params))` if LZ77 is beneficial,
/// or `None` if the savings are insufficient. "Insufficient" means the
/// estimated savings do not exceed `0.2 * tokens.len() + 16` bits.
pub fn apply_lz77_backref(
    tokens: &[Token],
    num_contexts: usize,
    force_huffman: bool,
    distance_multiplier: i32,
) -> Option<(Vec<Token>, Lz77Params)> {
    if tokens.is_empty() {
        return None;
    }

    let mut lz77 = Lz77Params::new(num_contexts, force_huffman);

    // Build cost estimator from original tokens
    let sce = SymbolCostEstimator::new(num_contexts, force_huffman, tokens, &lz77);

    // Compute cumulative bit costs for original stream
    // (sym_cost[k] = estimated bits for tokens[..k]; a range's literal cost is a difference)
    let mut sym_cost = vec![0.0f32; tokens.len() + 1];
    for (i, token) in tokens.iter().enumerate() {
        let e = UintCoder::encode(token.value);
        let cost = sce.symbol_cost(token.context() as usize, e.token as usize) + e.nbits as f32;
        sym_cost[i + 1] = sym_cost[i] + cost;
    }

    let mut out = Vec::with_capacity(tokens.len());
    let mut bit_decrease: f32 = 0.0;
    let total_symbols = tokens.len();

    let max_distance = tokens.len();
    let min_length = lz77.min_length as usize;
    let max_length = tokens.len();

    // Use next power of two as window size
    // (capped at WINDOW_SIZE; HashChain relies on it being a power of two)
    let mut window_size = 1usize;
    while window_size < max_distance && window_size < WINDOW_SIZE {
        window_size <<= 1;
    }

    let mut chain = HashChain::new(
        tokens,
        window_size,
        min_length,
        max_length,
        distance_multiplier,
    );

    const MAX_LAZY_MATCH_LEN: usize = 256;
    // True when position i + 1 was already inserted into the chain by the
    // lazy-match probe, so the post-match update must start one later.
    let mut already_updated = false;

    let mut i = 0usize;
    while i < tokens.len() {
        // Optimistically push the literal; it is rewritten in place if a match is taken.
        out.push(tokens[i]);

        if !already_updated {
            chain.update(i);
        }
        already_updated = false;

        let (mut dist_symbol, mut len) = chain.find_match(i, max_distance);

        if len >= min_length {
            // Try lazy matching: check if next position has a longer match
            if len < MAX_LAZY_MATCH_LEN && i + 1 < tokens.len() {
                chain.update(i + 1);
                already_updated = true;
                let (dist_symbol2, len2) = chain.find_match(i + 1, max_distance);
                if len2 > len {
                    // Use lazy match: emit literal for current position,
                    // then use match starting at next position
                    i += 1;
                    already_updated = false;
                    len = len2;
                    dist_symbol = dist_symbol2;
                    out.push(tokens[i]);
                }
            }

            // Compute costs
            let literal_cost = sym_cost[i + len] - sym_cost[i];
            let lz77_len = len - min_length;

            // Use empirical cost tables for LZ77 encoding
            let lz77_cost = len_cost(lz77_len as u32)
                + dist_cost(dist_symbol as u32)
                + sce.add_symbol_cost(out.last().unwrap().context() as usize);

            if lz77_cost <= literal_cost {
                // Emit LZ77 match
                // (the token at the match start becomes the length token, keeping its context)
                let last_token = out.last_mut().unwrap();
                last_token.value = lz77_len as u32;
                last_token.set_lz77_length(true);

                out.push(Token::new(lz77.distance_context, dist_symbol as u32));

                bit_decrease += literal_cost - lz77_cost;
            } else {
                // LZ77 not beneficial, emit literals
                for j in 1..len {
                    out.push(tokens[i + j]);
                }
            }

            // Update hash chain for matched positions
            // (position i is already in the chain; i + 1 too if the lazy probe ran and was rejected)
            if already_updated {
                chain.update_range(i + 2, len - 2);
                already_updated = false;
            } else {
                chain.update_range(i + 1, len - 1);
            }
            i += len - 1;
        }
        // Else: literal already pushed

        i += 1;
    }

    // Only use LZ77 if savings exceed threshold
    let threshold = total_symbols as f32 * 0.2 + 16.0;
    #[cfg(feature = "debug-tokens")]
    eprintln!(
        "[LZ77-backref] bit_decrease={:.1}, threshold={:.1}, tokens: {} -> {}, matches={}",
        bit_decrease,
        threshold,
        total_symbols,
        out.len(),
        out.iter().filter(|t| t.is_lz77_length()).count()
    );
    if bit_decrease > threshold {
        lz77.enabled = true;
        Some((out, lz77))
    } else {
        None
    }
}
780
781/// Apply RLE-based LZ77 compression to a token stream.
782///
783/// Scans for runs of consecutive identical values. When a run is long enough
784/// and the LZ77 encoding is cheaper, replaces the run with:
785/// 1. An LZ77 length token (context = original, value = run_len - min_length, is_lz77_length = true)
786/// 2. A distance token (context = distance_context, value = 0 meaning "repeat previous")
787///
788/// Returns `Some((transformed_tokens, params))` if LZ77 is beneficial,
789/// or `None` if the savings are insufficient.
790pub fn apply_lz77_rle(
791    tokens: &[Token],
792    num_contexts: usize,
793    force_huffman: bool,
794    distance_multiplier: i32,
795) -> Option<(Vec<Token>, Lz77Params)> {
796    if tokens.is_empty() {
797        return None;
798    }
799
800    let mut lz77 = Lz77Params::new(num_contexts, force_huffman);
801
802    // Compute the distance symbol that encodes distance=1 (repeat previous value).
803    // When dist_multiplier == 0: decoder uses distance_sym directly, so sym=0 → distance=0+1=1.
804    // When dist_multiplier > 0: decoder uses special distance table.
805    //   SPECIAL_DISTANCES[1] = (1, 0) → distance = 1 + dm*0 = 1 for any dm.
806    //   SPECIAL_DISTANCES[0] = (0, 1) → distance = 0 + dm*1 = dm (WRONG for RLE).
807    let rle_distance_symbol: u32 = if distance_multiplier > 0 { 1 } else { 0 };
808
809    // First pass: build cost estimator from the original tokens (no LZ77 tokens yet).
810    // We pass the original tokens to estimate costs, matching libjxl.
811    let sce = SymbolCostEstimator::new(num_contexts, force_huffman, tokens, &lz77);
812
813    // Compute cumulative bit costs for original stream.
814    let mut sym_cost = vec![0.0f32; tokens.len() + 1];
815    for (i, token) in tokens.iter().enumerate() {
816        let e = UintCoder::encode(token.value);
817        let cost = sce.symbol_cost(token.context() as usize, e.token as usize) + e.nbits as f32;
818        sym_cost[i + 1] = sym_cost[i] + cost;
819    }
820
821    let mut out = Vec::with_capacity(tokens.len());
822    let mut bit_decrease: f32 = 0.0;
823    let total_symbols = tokens.len();
824
825    let mut i = 0;
826    while i < tokens.len() {
827        // Count consecutive identical values starting from the PREVIOUS token
828        // (matching libjxl: "if (i > 0) { ... in[i+num_to_copy].value != in[i-1].value }")
829        let mut num_to_copy = 0;
830        if i > 0 {
831            let prev_value = tokens[i - 1].value;
832            while i + num_to_copy < tokens.len() && tokens[i + num_to_copy].value == prev_value {
833                num_to_copy += 1;
834            }
835        }
836
837        if num_to_copy == 0 {
838            out.push(tokens[i]);
839            i += 1;
840            continue;
841        }
842
843        // Cost of encoding the run literally
844        let literal_cost = sym_cost[i + num_to_copy] - sym_cost[i];
845
846        // Cost of LZ77 encoding (rough estimate matching libjxl)
847        let lz77_cost = if num_to_copy >= lz77.min_length as usize {
848            let lz77_len = num_to_copy - lz77.min_length as usize;
849            // CeilLog2Nonzero(lz77_len + 1) + 1 (for distance)
850            ceil_log2_nonzero((lz77_len + 1) as u32) as f32 + 1.0
851        } else {
852            0.0
853        };
854
855        if num_to_copy < lz77.min_length as usize || literal_cost <= lz77_cost {
856            // Not worth encoding as LZ77, emit literal tokens
857            for j in 0..num_to_copy {
858                out.push(tokens[i + j]);
859            }
860            i += num_to_copy;
861            continue;
862        }
863
864        // Emit LZ77 length token
865        let lz77_len = (num_to_copy - lz77.min_length as usize) as u32;
866        out.push(Token::lz77_length(tokens[i].context(), lz77_len));
867
868        // Emit distance token encoding distance=1 (repeat previous value)
869        out.push(Token::new(lz77.distance_context, rle_distance_symbol));
870
871        bit_decrease += literal_cost - lz77_cost;
872        i += num_to_copy;
873    }
874
875    // Only use LZ77 if savings exceed threshold (matching libjxl)
876    let threshold = total_symbols as f32 * 0.2 + 16.0;
877    #[cfg(feature = "debug-tokens")]
878    eprintln!(
879        "[LZ77-RLE] bit_decrease={:.1}, threshold={:.1}, tokens: {} -> {}, runs_found={}",
880        bit_decrease,
881        threshold,
882        total_symbols,
883        out.len(),
884        out.iter().filter(|t| t.is_lz77_length()).count()
885    );
886    if bit_decrease > threshold {
887        lz77.enabled = true;
888        Some((out, lz77))
889    } else {
890        None
891    }
892}
893
/// `ceil(log2(x))` for `x >= 1`, matching libjxl's `CeilLog2Nonzero`.
///
/// Passing `x == 0` is a caller error and trips the debug assertion.
fn ceil_log2_nonzero(x: u32) -> u32 {
    debug_assert!(x > 0);
    // The bit width of `x - 1` is the exponent of the smallest power of two >= x
    // (1 -> 0 bits, 2 -> 1, 3..=4 -> 2, 5..=8 -> 3, ...).
    u32::BITS - x.saturating_sub(1).leading_zeros()
}
904
/// Strategy used to introduce LZ77 backward references into a token stream.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum Lz77Method {
    /// Run-length only: a token may only repeat the value right before it (distance 1).
    /// Cheapest to compute, but gains little on photographic content.
    #[default]
    Rle,
    /// Greedy backward references found through hash chains anywhere in the sliding
    /// window. Typically 1-3% smaller than `Rle` on photos, at extra CPU cost.
    Greedy,
    /// Minimum-cost parse via Viterbi-style dynamic programming over every viable match
    /// (libjxl `ApplyLZ77_Optimal`). Best compression, slowest.
    Optimal,
}
921
922/// Apply LZ77 compression using the specified method.
923///
924/// - `Lz77Method::Rle`: RLE-only (fast, limited compression)
925/// - `Lz77Method::Greedy`: Hash chain backward references (slower, better compression)
926/// - `Lz77Method::Optimal`: Viterbi DP optimal parse (slowest, best compression)
927///
928/// For photographic content, `Greedy` typically provides 1-3% additional compression
929/// over RLE-only. `Optimal` finds the minimum-cost parse via dynamic programming.
930///
931/// Returns `Some((transformed_tokens, params))` if LZ77 is beneficial,
932/// or `None` if the savings are insufficient.
933pub fn apply_lz77(
934    tokens: &[Token],
935    num_contexts: usize,
936    force_huffman: bool,
937    method: Lz77Method,
938    distance_multiplier: i32,
939) -> Option<(Vec<Token>, Lz77Params)> {
940    match method {
941        Lz77Method::Rle => apply_lz77_rle(tokens, num_contexts, force_huffman, distance_multiplier),
942        Lz77Method::Greedy => {
943            apply_lz77_backref(tokens, num_contexts, force_huffman, distance_multiplier)
944        }
945        Lz77Method::Optimal => {
946            apply_lz77_optimal(tokens, num_contexts, force_huffman, distance_multiplier)
947        }
948    }
949}
950
951/// Apply optimal LZ77 with Viterbi DP parsing (from libjxl `ApplyLZ77_Optimal`).
952///
953/// Uses dynamic programming to find the minimum-cost parse of the token stream.
954/// First runs greedy LZ77 to build a cost model, then uses that model with
955/// forward-pass DP to find the optimal literal/match decisions at each position.
956///
957/// Returns `Some((transformed_tokens, params))` if LZ77 is beneficial,
958/// or `None` if the savings are insufficient.
959pub fn apply_lz77_optimal(
960    tokens: &[Token],
961    num_contexts: usize,
962    force_huffman: bool,
963    distance_multiplier: i32,
964) -> Option<(Vec<Token>, Lz77Params)> {
965    if tokens.is_empty() {
966        return None;
967    }
968
969    // Step 1: Run greedy LZ77 to get a cost estimate.
970    // If greedy doesn't help, optimal won't either.
971    let greedy_result =
972        apply_lz77_backref(tokens, num_contexts, force_huffman, distance_multiplier);
973    let greedy_tokens = match &greedy_result {
974        Some((t, _)) => t,
975        None => return None,
976    };
977
978    let mut lz77 = Lz77Params::new(num_contexts, force_huffman);
979    lz77.enabled = true;
980
981    // Step 2: Build cost estimator from greedy result (num_contexts + 1 for distance ctx).
982    let sce = SymbolCostEstimator::new(num_contexts + 1, force_huffman, greedy_tokens, &lz77);
983
984    // Step 3: Compute cumulative symbol costs for the original (non-LZ77) stream.
985    let mut sym_cost = vec![0.0f32; tokens.len() + 1];
986    for (i, token) in tokens.iter().enumerate() {
987        let e = UintCoder::encode(token.value);
988        let cost = sce.symbol_cost(token.context() as usize, e.token as usize) + e.nbits as f32;
989        sym_cost[i + 1] = sym_cost[i] + cost;
990    }
991
992    // Step 4: Forward DP pass.
993    let max_distance = tokens.len();
994    let min_length = lz77.min_length as usize;
995    let max_length = tokens.len();
996
997    let mut window_size = 1usize;
998    while window_size < max_distance && window_size < WINDOW_SIZE {
999        window_size <<= 1;
1000    }
1001
1002    let mut chain = HashChain::new(
1003        tokens,
1004        window_size,
1005        min_length,
1006        max_length,
1007        distance_multiplier,
1008    );
1009
1010    // MatchInfo for backtrace: len=1 means literal, dist_symbol stored as +1 (0 = literal).
1011    struct PrefixInfo {
1012        len: u32,
1013        dist_symbol: u32, // 0 = literal, >0 = LZ77 match (actual dist_symbol + 1)
1014        ctx: u32,
1015        total_cost: f32,
1016    }
1017
1018    let n = tokens.len();
1019    let mut prefix_costs: Vec<PrefixInfo> = (0..=n)
1020        .map(|_| PrefixInfo {
1021            len: 0,
1022            dist_symbol: 0,
1023            ctx: 0,
1024            total_cost: f32::MAX,
1025        })
1026        .collect();
1027    prefix_costs[0].total_cost = 0.0;
1028
1029    let mut rle_length = 0usize;
1030    let mut skip_lz77 = 0usize;
1031    let mut dist_symbols: Vec<u32> = Vec::new();
1032
1033    for i in 0..n {
1034        chain.update(i);
1035
1036        // Literal cost
1037        let lit_cost = prefix_costs[i].total_cost + sym_cost[i + 1] - sym_cost[i];
1038        if prefix_costs[i + 1].total_cost > lit_cost {
1039            prefix_costs[i + 1].dist_symbol = 0;
1040            prefix_costs[i + 1].len = 1;
1041            prefix_costs[i + 1].ctx = tokens[i].context();
1042            prefix_costs[i + 1].total_cost = lit_cost;
1043        }
1044
1045        if skip_lz77 > 0 {
1046            skip_lz77 -= 1;
1047            continue;
1048        }
1049
1050        // Collect all matches: for each length, keep the cheapest dist_symbol.
1051        dist_symbols.clear();
1052        chain.find_matches(i, max_distance, |len, dist_symbol| {
1053            if dist_symbols.len() <= len {
1054                dist_symbols.resize(len + 1, dist_symbol as u32);
1055            }
1056            if (dist_symbol as u32) < dist_symbols[len] {
1057                dist_symbols[len] = dist_symbol as u32;
1058            }
1059        });
1060
1061        if dist_symbols.len() <= min_length {
1062            continue;
1063        }
1064
1065        // Normalize: for each length, use the best dist_symbol from any longer match.
1066        {
1067            let mut best_cost = dist_symbols[dist_symbols.len() - 1];
1068            for j in (min_length..dist_symbols.len()).rev() {
1069                if dist_symbols[j] < best_cost {
1070                    best_cost = dist_symbols[j];
1071                }
1072                dist_symbols[j] = best_cost;
1073            }
1074        }
1075
1076        // Evaluate each match length.
1077        for (j, &dsym) in dist_symbols.iter().enumerate().skip(min_length) {
1078            let target = i + j;
1079            if target > n {
1080                break;
1081            }
1082            let lz77_cost =
1083                sce.len_cost(tokens[i].context() as usize, (j - min_length) as u32, &lz77)
1084                    + sce.dist_cost_sce(dsym, &lz77);
1085            let cost = prefix_costs[i].total_cost + lz77_cost;
1086            if prefix_costs[target].total_cost > cost {
1087                prefix_costs[target].len = j as u32;
1088                prefix_costs[target].dist_symbol = dsym + 1; // +1 to distinguish from literal
1089                prefix_costs[target].ctx = tokens[i].context();
1090                prefix_costs[target].total_cost = cost;
1091            }
1092        }
1093
1094        // RLE skip optimization: avoid O(n^2) on long runs of same distance.
1095        let last_dist = dist_symbols[dist_symbols.len() - 1];
1096        if (last_dist == 0 && distance_multiplier == 0)
1097            || (last_dist == 1 && distance_multiplier != 0)
1098        {
1099            rle_length += 1;
1100        } else {
1101            rle_length = 0;
1102        }
1103        if rle_length >= 8 && dist_symbols.len() > 9 {
1104            skip_lz77 = dist_symbols.len() - 10;
1105            rle_length = 0;
1106        }
1107    }
1108
1109    // Step 5: Backtrace from end to beginning.
1110    let mut out = Vec::with_capacity(n);
1111    let mut pos = n;
1112    while pos > 0 {
1113        let info = &prefix_costs[pos];
1114        let is_lz77 = info.dist_symbol != 0;
1115
1116        if is_lz77 {
1117            let dist_symbol = info.dist_symbol - 1;
1118            out.push(Token::new(lz77.distance_context, dist_symbol));
1119        }
1120
1121        let val = if is_lz77 {
1122            info.len - min_length as u32
1123        } else {
1124            tokens[pos - 1].value
1125        };
1126        let mut tok = Token::new(info.ctx, val);
1127        tok.set_lz77_length(is_lz77);
1128        out.push(tok);
1129
1130        pos -= info.len as usize;
1131    }
1132
1133    out.reverse();
1134    Some((out, lz77))
1135}
1136
1137/// Try both LZ77 methods and return the one with better compression.
1138///
1139/// This is useful when you want the best compression regardless of speed.
1140/// Returns the method that produces fewer output tokens, or None if neither
1141/// method provides sufficient savings.
1142#[allow(dead_code)] // Utility function for advanced users
1143pub fn apply_lz77_best(
1144    tokens: &[Token],
1145    num_contexts: usize,
1146    force_huffman: bool,
1147    distance_multiplier: i32,
1148) -> Option<(Vec<Token>, Lz77Params)> {
1149    let rle_result = apply_lz77_rle(tokens, num_contexts, force_huffman, distance_multiplier);
1150    let backref_result =
1151        apply_lz77_backref(tokens, num_contexts, force_huffman, distance_multiplier);
1152
1153    match (&rle_result, &backref_result) {
1154        (Some((rle_tokens, _)), Some((backref_tokens, _))) => {
1155            // Return whichever produces fewer tokens
1156            if backref_tokens.len() <= rle_tokens.len() {
1157                backref_result
1158            } else {
1159                rle_result
1160            }
1161        }
1162        (Some(_), None) => rle_result,
1163        (None, Some(_)) => backref_result,
1164        (None, None) => None,
1165    }
1166}
1167
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic pseudo-random values in `0..modulus` (simple LCG, no extra crates).
    fn pseudo_random_values(count: usize, seed: u32, modulus: u32) -> Vec<u32> {
        let mut state = seed;
        (0..count)
            .map(|_| {
                state = state.wrapping_mul(1_664_525).wrapping_add(1_013_904_223);
                (state >> 8) % modulus
            })
            .collect()
    }

    /// A random pattern repeated 12 times, each copy followed by a run of 32 identical
    /// values, so RLE, greedy and optimal parsing all have something to find.
    fn mixed_repeating_stream() -> Vec<u32> {
        let pattern = pseudo_random_values(48, 12345, 4096);
        let mut values = Vec::new();
        for &run_value in pattern.iter().take(12) {
            values.extend_from_slice(&pattern);
            values.resize(values.len() + 32, run_value);
        }
        values
    }

    /// Expand an LZ77-transformed stream back into plain values.
    ///
    /// Only valid for streams produced with `distance_multiplier == 0`, where a
    /// distance symbol `s` encodes distance `s + 1`.
    fn decode_lz77(tokens: &[Token], params: &Lz77Params) -> Vec<u32> {
        let mut values: Vec<u32> = Vec::with_capacity(tokens.len());
        let mut iter = tokens.iter();
        while let Some(tok) = iter.next() {
            if !tok.is_lz77_length() {
                values.push(tok.value);
                continue;
            }
            let dist_tok = iter
                .next()
                .expect("LZ77 length token must be followed by a distance token");
            assert!(!dist_tok.is_lz77_length());
            assert_eq!(dist_tok.context(), params.distance_context);
            let len = tok.value as usize + params.min_length as usize;
            let dist = dist_tok.value as usize + 1;
            assert!(
                dist <= values.len(),
                "distance {} exceeds decoded history {}",
                dist,
                values.len()
            );
            // Element-by-element copy: overlapping copies (dist < len) are allowed.
            for _ in 0..len {
                let v = values[values.len() - dist];
                values.push(v);
            }
        }
        values
    }

    #[test]
    fn test_ceil_log2_nonzero() {
        assert_eq!(ceil_log2_nonzero(1), 0);
        assert_eq!(ceil_log2_nonzero(2), 1);
        assert_eq!(ceil_log2_nonzero(3), 2);
        assert_eq!(ceil_log2_nonzero(4), 2);
        assert_eq!(ceil_log2_nonzero(5), 3);
        assert_eq!(ceil_log2_nonzero(8), 3);
        assert_eq!(ceil_log2_nonzero(9), 4);
    }

    #[test]
    fn test_no_rle_on_short_stream() {
        // Very short streams shouldn't trigger LZ77
        let tokens = vec![Token::new(0, 5), Token::new(0, 5), Token::new(0, 5)];
        assert!(apply_lz77_rle(&tokens, 1, false, 0).is_none());
    }

    #[test]
    fn test_rle_on_long_run() {
        // Long run of identical values should trigger LZ77
        let mut tokens = Vec::new();
        // Need a "previous" token, then a long run of identical values
        tokens.push(Token::new(0, 5));
        for _ in 0..200 {
            tokens.push(Token::new(0, 5));
        }

        let result = apply_lz77_rle(&tokens, 1, false, 0);
        if let Some((lz77_tokens, params)) = result {
            assert!(params.enabled);
            // Should be much shorter than the original
            assert!(lz77_tokens.len() < tokens.len());
            // Should contain at least one LZ77 length token
            assert!(lz77_tokens.iter().any(|t| t.is_lz77_length()));
        }
        // If None, that's OK — the threshold might not be met for this particular cost estimate
    }

    #[test]
    fn test_rle_preserves_non_runs() {
        // Mixed content: some runs, some non-runs
        let mut tokens = Vec::new();
        // Non-repeating prefix
        for i in 0..10 {
            tokens.push(Token::new(0, i));
        }
        // Long run
        for _ in 0..100 {
            tokens.push(Token::new(0, 42));
        }
        // Non-repeating suffix
        for i in 0..10 {
            tokens.push(Token::new(0, i + 100));
        }

        if let Some((lz77_tokens, params)) = apply_lz77_rle(&tokens, 1, false, 0) {
            assert!(params.enabled);
            assert!(lz77_tokens.len() < tokens.len());
            // The first token should be preserved literally
            assert_eq!(lz77_tokens[0].value, 0);
            assert!(!lz77_tokens[0].is_lz77_length());
        }
    }

    #[test]
    fn test_empty_stream() {
        assert!(apply_lz77_rle(&[], 1, false, 0).is_none());
    }

    // Tests for backward-reference LZ77

    #[test]
    fn test_backref_empty_stream() {
        assert!(apply_lz77_backref(&[], 1, false, 0).is_none());
    }

    #[test]
    fn test_backref_short_stream() {
        // Very short streams shouldn't trigger LZ77
        let tokens = vec![Token::new(0, 5), Token::new(0, 5), Token::new(0, 5)];
        assert!(apply_lz77_backref(&tokens, 1, false, 0).is_none());
    }

    #[test]
    fn test_backref_on_repeating_pattern() {
        // Pattern that repeats at distance > 1 (not just RLE)
        // Pattern: [A, B, C, A, B, C, A, B, C, ...]
        let mut tokens = Vec::new();
        for _ in 0..100 {
            tokens.push(Token::new(0, 10));
            tokens.push(Token::new(0, 20));
            tokens.push(Token::new(0, 30));
        }

        let result = apply_lz77_backref(&tokens, 1, false, 0);
        if let Some((lz77_tokens, params)) = result {
            assert!(params.enabled);
            // Should be significantly shorter due to backward references
            assert!(
                lz77_tokens.len() < tokens.len(),
                "backref should compress pattern: {} vs {}",
                lz77_tokens.len(),
                tokens.len()
            );
            // Should have LZ77 length tokens
            assert!(lz77_tokens.iter().any(|t| t.is_lz77_length()));
        }
    }

    #[test]
    fn test_backref_finds_longer_matches_than_rle() {
        // Pattern where backref can find matches that RLE cannot
        // [1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, ...]
        let mut tokens = Vec::new();
        for _ in 0..50 {
            for j in 1..=5 {
                tokens.push(Token::new(0, j));
            }
        }

        let rle_result = apply_lz77_rle(&tokens, 1, false, 0);
        let backref_result = apply_lz77_backref(&tokens, 1, false, 0);

        // RLE should not find matches (no consecutive identical values)
        // Backref should find matches at distance 5
        match (&rle_result, &backref_result) {
            (None, Some((backref_tokens, _))) => {
                // This is the expected case: RLE finds nothing, backref does
                assert!(backref_tokens.len() < tokens.len());
            }
            (Some((rle_tokens, _)), Some((backref_tokens, _))) => {
                // If both activate, backref should do better or equal
                assert!(backref_tokens.len() <= rle_tokens.len());
            }
            _ => {
                // Either both fail (acceptable for small patterns) or both succeed
            }
        }
    }

    #[test]
    fn test_backref_with_distance_multiplier() {
        // Test that special distance codes work with distance multiplier
        // When multiplier is non-zero, distances like image_width are encoded more efficiently
        let mut tokens = Vec::new();
        let image_width = 64;

        // Create pattern that repeats at image_width distance (previous row)
        for _row in 0..20 {
            for col in 0..image_width {
                // Same value for same column across rows
                tokens.push(Token::new(0, (col % 16) as u32));
            }
        }

        let _result_no_mult = apply_lz77_backref(&tokens, 1, false, 0);
        let result_with_mult = apply_lz77_backref(&tokens, 1, false, image_width);

        // Both should find matches; with multiplier might be more efficient
        // but the main test is that it doesn't crash and produces valid output
        if let Some((tokens_mult, params)) = result_with_mult {
            assert!(params.enabled);
            assert!(tokens_mult.len() < tokens.len());
        }
    }

    #[test]
    fn test_special_distance() {
        // Test special distance calculation
        // kSpecialDistances[0] = [0, 1] -> distance = 0 + 1*multiplier = multiplier
        assert_eq!(special_distance(0, 64), 64);
        // kSpecialDistances[1] = [1, 0] -> distance = 1 + 0*multiplier = 1
        assert_eq!(special_distance(1, 64), 1);
        // kSpecialDistances[2] = [1, 1] -> distance = 1 + 1*multiplier = 65
        assert_eq!(special_distance(2, 64), 65);
        // kSpecialDistances[3] = [-1, 1] -> distance = -1 + 1*multiplier = 63
        assert_eq!(special_distance(3, 64), 63);
    }

    #[test]
    fn test_len_cost() {
        // Verify len_cost doesn't panic on various inputs
        for len in 0..1000 {
            let cost = len_cost(len);
            assert!(cost >= 0.0, "len_cost({}) should be non-negative", len);
            assert!(cost < 100.0, "len_cost({}) should be reasonable", len);
        }
    }

    #[test]
    fn test_dist_cost() {
        // Verify dist_cost doesn't panic on various inputs
        for dist in 0..10000 {
            let cost = dist_cost(dist);
            assert!(cost >= 0.0, "dist_cost({}) should be non-negative", dist);
            assert!(cost < 100.0, "dist_cost({}) should be reasonable", dist);
        }
    }

    #[test]
    fn test_apply_lz77_method_enum() {
        let mut tokens = Vec::new();
        tokens.push(Token::new(0, 5));
        for _ in 0..200 {
            tokens.push(Token::new(0, 5));
        }

        // Test RLE method
        let rle_result = apply_lz77(&tokens, 1, false, Lz77Method::Rle, 0);
        if let Some((_, params)) = &rle_result {
            assert!(params.enabled);
        }

        // Test Greedy method
        let greedy_result = apply_lz77(&tokens, 1, false, Lz77Method::Greedy, 0);
        if let Some((_, params)) = &greedy_result {
            assert!(params.enabled);
        }

        // Test Optimal method
        let optimal_result = apply_lz77(&tokens, 1, false, Lz77Method::Optimal, 0);
        if let Some((_, params)) = &optimal_result {
            assert!(params.enabled);
        }
    }

    #[test]
    fn test_optimal_none_on_empty_and_short_streams() {
        assert!(apply_lz77_optimal(&[], 1, false, 0).is_none());
        // Greedy rejects this stream (see test_backref_short_stream), so optimal must too.
        let tokens = vec![Token::new(0, 5), Token::new(0, 5), Token::new(0, 5)];
        assert!(apply_lz77_optimal(&tokens, 1, false, 0).is_none());
    }

    #[test]
    fn test_all_methods_roundtrip() {
        // Whatever each method emits must expand back to exactly the input values.
        let values = mixed_repeating_stream();
        let tokens: Vec<Token> = values.iter().map(|&v| Token::new(0, v)).collect();
        for method in [Lz77Method::Rle, Lz77Method::Greedy, Lz77Method::Optimal] {
            if let Some((lz77_tokens, params)) = apply_lz77(&tokens, 1, false, method, 0) {
                assert!(params.enabled, "{:?} returned disabled params", method);
                assert_eq!(
                    decode_lz77(&lz77_tokens, &params),
                    values,
                    "{:?} output does not round-trip",
                    method
                );
            }
        }
    }

    #[test]
    fn test_apply_lz77_best() {
        // Pattern where backref should do better
        let mut tokens = Vec::new();
        for _ in 0..50 {
            for j in 1..=10 {
                tokens.push(Token::new(0, j));
            }
        }

        let best_result = apply_lz77_best(&tokens, 1, false, 0);
        // Should pick the better method (likely backref for this pattern)
        if let Some((best_tokens, params)) = best_result {
            assert!(params.enabled);
            assert!(best_tokens.len() < tokens.len());
        }
    }

    #[test]
    fn test_hash_chain_basic() {
        // Test hash chain finds simple matches
        let tokens = vec![
            Token::new(0, 10),
            Token::new(0, 20),
            Token::new(0, 30),
            Token::new(0, 40), // Different sequence
            Token::new(0, 10),
            Token::new(0, 20),
            Token::new(0, 30), // Repeats tokens 0-2
        ];

        let mut chain = HashChain::new(&tokens, 16, 3, 100, 0);
        // Update chain for all positions
        for i in 0..tokens.len() {
            chain.update(i);
        }

        // At position 4, should find match at position 0 (distance 4)
        let (dist_symbol, len) = chain.find_match(4, 10);
        assert!(len >= 3, "should find match of length >= 3, got {}", len);
        // dist_symbol should encode distance 4 (no special distances with multiplier=0)
        // Special distances: 0 entries since multiplier=0
        // So dist_symbol = num_special_distances + dist - 1 = 0 + 4 - 1 = 3
        assert_eq!(dist_symbol, 3, "distance symbol for dist=4 should be 3");
    }
}
1434        // Special distances: 0 entries since multiplier=0
1435        // So dist_symbol = num_special_distances + dist - 1 = 0 + 4 - 1 = 3
1436        assert_eq!(dist_symbol, 3, "distance symbol for dist=4 should be 3");
1437    }
1438}