Skip to main content

jxl_encoder/entropy_coding/
lz77.rs

1// Copyright (c) Imazen LLC and the JPEG XL Project Authors.
2// Algorithms and constants derived from libjxl (BSD-3-Clause).
3// Licensed under AGPL-3.0-or-later. Commercial licenses at https://www.imazen.io/pricing
4
5//! LZ77 backward references for entropy coding.
6//!
7//! Implements both RLE-only and full backward-reference LZ77 methods from libjxl.
8//! - `apply_lz77_rle`: RLE-only (consecutive identical values)
9//! - `apply_lz77_backref`: Full backward references with hash chains (greedy matching)
10//!
11//! The backward-reference method uses hash chains to find matches at arbitrary
12//! distances within a sliding window, providing 1-3% compression improvement on
13//! photographic content compared to RLE-only.
14
15use hashbrown::HashMap;
16
17use super::token::{Lz77UintCoder, Token, UintCoder};
18
/// Maximum window size for LZ77 matching: 2^20 token positions (matches libjxl).
///
/// `apply_lz77_backref` sizes its hash-chain window as the smallest power of
/// two covering the stream, capped at this value.
const WINDOW_SIZE: usize = 1 << 20;

/// Number of special distance codes from WebP lossless.
///
/// When a non-zero distance multiplier is in use, distance symbols below this
/// value select an entry of `SPECIAL_DISTANCES`; any larger symbol `s` encodes
/// the plain distance `s - NUM_SPECIAL_DISTANCES + 1`.
const NUM_SPECIAL_DISTANCES: usize = 120;

/// Special distance codes from WebP lossless.
/// Each entry is [offset, multiplier] where distance = offset + multiplier * image_width.
/// These encode common 2D patterns (horizontal, vertical, diagonal) compactly.
///
/// `image_width` is the stream's distance multiplier (see `special_distance`).
/// For example `[1, 0]` is the previous token and `[0, 1]` is the token one row
/// back. Entries appear to be ordered by increasing 2D distance.
#[rustfmt::skip]
const SPECIAL_DISTANCES: [[i8; 2]; NUM_SPECIAL_DISTANCES] = [
    [0, 1],  [1, 0],  [1, 1],  [-1, 1], [0, 2],  [2, 0],  [1, 2],  [-1, 2],
    [2, 1],  [-2, 1], [2, 2],  [-2, 2], [0, 3],  [3, 0],  [1, 3],  [-1, 3],
    [3, 1],  [-3, 1], [2, 3],  [-2, 3], [3, 2],  [-3, 2], [0, 4],  [4, 0],
    [1, 4],  [-1, 4], [4, 1],  [-4, 1], [3, 3],  [-3, 3], [2, 4],  [-2, 4],
    [4, 2],  [-4, 2], [0, 5],  [3, 4],  [-3, 4], [4, 3],  [-4, 3], [5, 0],
    [1, 5],  [-1, 5], [5, 1],  [-5, 1], [2, 5],  [-2, 5], [5, 2],  [-5, 2],
    [4, 4],  [-4, 4], [3, 5],  [-3, 5], [5, 3],  [-5, 3], [0, 6],  [6, 0],
    [1, 6],  [-1, 6], [6, 1],  [-6, 1], [2, 6],  [-2, 6], [6, 2],  [-6, 2],
    [4, 5],  [-4, 5], [5, 4],  [-5, 4], [3, 6],  [-3, 6], [6, 3],  [-6, 3],
    [0, 7],  [7, 0],  [1, 7],  [-1, 7], [5, 5],  [-5, 5], [7, 1],  [-7, 1],
    [4, 6],  [-4, 6], [6, 4],  [-6, 4], [2, 7],  [-2, 7], [7, 2],  [-7, 2],
    [3, 7],  [-3, 7], [7, 3],  [-7, 3], [5, 6],  [-5, 6], [6, 5],  [-6, 5],
    [8, 0],  [4, 7],  [-4, 7], [7, 4],  [-7, 4], [8, 1],  [8, 2],  [6, 6],
    [-6, 6], [8, 3],  [5, 7],  [-5, 7], [7, 5],  [-7, 5], [8, 4],  [6, 7],
    [-6, 7], [7, 6],  [-7, 6], [8, 5],  [7, 7],  [-7, 7], [8, 6],  [8, 7],
];
46
47/// Compute special distance from code index and distance multiplier (image width).
48#[inline]
49fn special_distance(index: usize, multiplier: i32) -> i32 {
50    SPECIAL_DISTANCES[index][0] as i32 + multiplier * SPECIAL_DISTANCES[index][1] as i32
51}
52
/// Empirical cost table for LZ77 length encoding (from libjxl).
/// Indexed by token value from HybridUintConfig(1, 0, 0).
///
/// Each entry is an estimated cost in bits for one length token. `len_cost`
/// adds the token's raw extra bits on top, and reuses the last entry for any
/// token past the end of the table.
#[rustfmt::skip]
#[allow(clippy::excessive_precision)] // Values are kept exactly as published.
const LEN_COST_TABLE: [f32; 17] = [
    2.797667318563126,  3.213177690381199,  2.5706009246743737,
    2.408392498667534,  2.829649191872326,  3.3923087753324577,
    4.029267451554331,  4.415576699706408,  4.509357574741465,
    9.21481543803004,   10.020590190114898, 11.858671627804766,
    12.45853300490526,  11.713105831990857, 12.561996324849314,
    13.775477692278367, 13.174027068768641,
];
65
/// Empirical cost table for LZ77 distance encoding (from libjxl).
/// Indexed by token value from HybridUintConfig(7, 0, 0).
///
/// Each entry is an estimated cost in bits for one distance token. `dist_cost`
/// adds the token's raw extra bits on top, and reuses the last entry for any
/// token past the end of the table.
#[rustfmt::skip]
#[allow(clippy::excessive_precision)] // Values are kept exactly as published.
const DIST_COST_TABLE: [f32; 128] = [
    6.368282626312716,  5.680793277090298,  8.347404197105247,
    7.641619201599141,  6.914328374119438,  7.959808291537444,
    8.70023120759855,   8.71378518934703,   9.379132523982769,
    9.110472749092708,  9.159029569270908,  9.430936766731973,
    7.278284055315169,  7.8278514904267755, 10.026641158289236,
    9.976049229827066,  9.64351607048908,   9.563403863480442,
    10.171474111762747, 10.45950155077234,  9.994813912104219,
    10.322524683741156, 8.465808729388186,  8.756254166066853,
    10.160930174662234, 10.247329273413435, 10.04090403724809,
    10.129398517544082, 9.342311691539546,  9.07608009102374,
    10.104799540677513, 10.378079384990906, 10.165828974075072,
    10.337595322341553, 7.940557464567944,  10.575665823319431,
    11.023344321751955, 10.736144698831827, 11.118277044595054,
    7.468468230648442,  10.738305230932939, 10.906980780216568,
    10.163468216353817, 10.17805759656433,  11.167283670483565,
    11.147050200274544, 10.517921919244333, 10.651764778156886,
    10.17074446448919,  11.217636876224745, 11.261630721139484,
    11.403140815247259, 10.892472096873417, 11.1859607804481,
    8.017346947551262,  7.895143720278828,  11.036577113822025,
    11.170562110315794, 10.326988722591086, 10.40872184751056,
    11.213498225466386, 11.30580635516863,  10.672272515665442,
    10.768069466228063, 11.145257364153565, 11.64668307145549,
    10.593156194627339, 11.207499484844943, 10.767517766396908,
    10.826629811407042, 10.737764794499988, 10.6200448518045,
    10.191315385198092, 8.468384171390085,  11.731295299170432,
    11.824619886654398, 10.41518844301179,  10.16310536548649,
    10.539423685097576, 10.495136599328031, 10.469112847728267,
    11.72057686174922,  10.910326337834674, 11.378921834673758,
    11.847759036098536, 11.92071647623854,  10.810628276345282,
    11.008601085273893, 11.910326337834674, 11.949212023423133,
    11.298614839104337, 11.611603659010392, 10.472930394619985,
    11.835564720850282, 11.523267392285337, 12.01055816679611,
    8.413029688994023,  11.895784139536406, 11.984679534970505,
    11.220654278717394, 11.716311684833672, 10.61036646226114,
    10.89849965960364,  10.203762898863669, 10.997560826267238,
    11.484217379438984, 11.792836176993665, 12.24310468755171,
    11.464858097919262, 12.212747017409377, 11.425595666074955,
    11.572048533398757, 12.742093965163013, 11.381874288645637,
    12.191870445817015, 11.683156920035426, 11.152442115262197,
    11.90303691580457,  11.653292787169159, 11.938615382266098,
    16.970641701570223, 16.853602280380002, 17.26240782594733,
    16.644655390108507, 17.14310889757499,  16.910935455445955,
    17.505678976959697, 17.213498225466388,
];
115
116/// Empirical cost for LZ77 length encoding.
117fn len_cost(len: u32) -> f32 {
118    // HybridUintConfig(1, 0, 0): token = 1 + floor_log2(len) for len >= 1
119    let (tok, nbits) = if len == 0 {
120        (0u32, 0u32)
121    } else {
122        let n = 31 - len.leading_zeros();
123        (1 + n, n)
124    };
125    let table_size = LEN_COST_TABLE.len();
126    let tok_idx = (tok as usize).min(table_size - 1);
127    LEN_COST_TABLE[tok_idx] + nbits as f32
128}
129
130/// Empirical cost for LZ77 distance encoding.
131fn dist_cost(dist: u32) -> f32 {
132    // HybridUintConfig(7, 0, 0): different split point
133    let (tok, nbits) = hybrid_uint_encode_7_0_0(dist);
134    let table_size = DIST_COST_TABLE.len();
135    let tok_idx = (tok as usize).min(table_size - 1);
136    DIST_COST_TABLE[tok_idx] + nbits as f32
137}
138
/// HybridUint encoding with config (split_exponent = 7, msb_in_token = 0,
/// lsb_in_token = 0), used to tokenize LZ77 distance symbols for cost estimation.
///
/// Returns `(token, nbits)`:
/// - values below the split token `1 << 7 = 128` are their own token, with no
///   raw bits;
/// - a value `v >= 128` with `n = floor(log2(v))` becomes token `128 + (n - 7)`,
///   followed by `n` raw bits.
///
/// This is HybridUintConfig::Encode specialised to `(7, 0, 0)`. The split point
/// is `1 << split_exponent`, not `split_exponent` itself. With the same rule,
/// config `(1, 0, 0)` yields `len_cost`'s `(1 + n, n)` mapping.
fn hybrid_uint_encode_7_0_0(value: u32) -> (u32, u32) {
    const SPLIT_EXPONENT: u32 = 7;
    const SPLIT_TOKEN: u32 = 1 << SPLIT_EXPONENT;

    if value < SPLIT_TOKEN {
        (value, 0)
    } else {
        // value >= 128, so n >= 7 and the subtraction cannot underflow.
        let n = 31 - value.leading_zeros();
        (SPLIT_TOKEN + (n - SPLIT_EXPONENT), n)
    }
}
152
/// LZ77 settings that are written into the entropy-code header.
#[derive(Debug, Clone)]
pub struct Lz77Params {
    /// Whether the LZ77 transform is active for this stream.
    pub enabled: bool,
    /// First symbol index reserved for LZ77 length tokens.
    /// 224 with ANS, 512 when Huffman coding is forced.
    pub min_symbol: u32,
    /// Shortest run/match worth replacing by an LZ77 copy (3 by default).
    pub min_length: u32,
    /// Context used for LZ77 distance tokens: the number of contexts that
    /// existed before LZ77, i.e. one past the last regular context.
    pub distance_context: u32,
}

impl Lz77Params {
    /// Disabled LZ77 parameters for a stream with `num_contexts` contexts.
    ///
    /// `force_huffman` selects the Huffman symbol layout (length tokens start
    /// at 512) instead of the ANS layout (length tokens start at 224).
    pub fn new(num_contexts: usize, force_huffman: bool) -> Self {
        const ANS_MIN_SYMBOL: u32 = 224;
        const HUFFMAN_MIN_SYMBOL: u32 = 512;
        const DEFAULT_MIN_LENGTH: u32 = 3;

        let min_symbol = if force_huffman {
            HUFFMAN_MIN_SYMBOL
        } else {
            ANS_MIN_SYMBOL
        };
        let distance_context = num_contexts as u32;

        Self {
            enabled: false,
            min_symbol,
            min_length: DEFAULT_MIN_LENGTH,
            distance_context,
        }
    }
}
176
177/// Estimate per-symbol bit cost from histograms, matching libjxl's SymbolCostEstimator.
178struct SymbolCostEstimator {
179    /// Flat array: bits[ctx * max_alphabet_size + sym]
180    bits: Vec<f32>,
181    max_alphabet_size: usize,
182}
183
184impl SymbolCostEstimator {
185    fn new(num_contexts: usize, force_huffman: bool, tokens: &[Token], lz77: &Lz77Params) -> Self {
186        const ANS_LOG_TAB_SIZE: f32 = 12.0;
187
188        // Build per-context histograms from the (possibly LZ77-transformed) tokens.
189        let mut counts: Vec<Vec<u32>> = vec![vec![]; num_contexts];
190        let mut total_counts = vec![0u32; num_contexts];
191
192        for token in tokens {
193            let (tok, _nbits) = if token.is_lz77_length {
194                let e = Lz77UintCoder::encode(token.value);
195                (e.token + lz77.min_symbol, e.nbits)
196            } else {
197                let e = UintCoder::encode(token.value);
198                (e.token, e.nbits)
199            };
200            let ctx = token.context as usize;
201            if ctx < num_contexts {
202                let sym = tok as usize;
203                if sym >= counts[ctx].len() {
204                    counts[ctx].resize(sym + 1, 0);
205                }
206                counts[ctx][sym] += 1;
207                total_counts[ctx] += 1;
208            }
209        }
210
211        let max_alphabet_size = counts.iter().map(|c| c.len()).max().unwrap_or(0);
212        let mut bits = vec![0.0f32; num_contexts * max_alphabet_size];
213
214        for ctx in 0..num_contexts {
215            let total = total_counts[ctx];
216            if total == 0 {
217                continue;
218            }
219            let inv_total = 1.0 / (total as f32 + 1e-8);
220            for sym in 0..counts[ctx].len() {
221                let cnt = counts[ctx][sym];
222                let cost = if cnt != 0 && cnt != total {
223                    let p = cnt as f32 * inv_total;
224                    let c = -p.log2();
225                    if force_huffman { c.ceil() } else { c }
226                } else if cnt == 0 {
227                    ANS_LOG_TAB_SIZE // Highest possible cost
228                } else {
229                    0.0 // Single symbol, zero cost
230                };
231                bits[ctx * max_alphabet_size + sym] = cost;
232            }
233        }
234
235        Self {
236            bits,
237            max_alphabet_size,
238        }
239    }
240
241    #[inline]
242    fn symbol_cost(&self, ctx: usize, sym: usize) -> f32 {
243        if sym < self.max_alphabet_size {
244            self.bits[ctx * self.max_alphabet_size + sym]
245        } else {
246            12.0 // ANS_LOG_TAB_SIZE as fallback
247        }
248    }
249
250    /// Cost of adding an LZ77 symbol to a context (penalty for low-entropy contexts).
251    fn add_symbol_cost(&self, ctx: usize) -> f32 {
252        // Compute average cost per symbol in this context
253        let mut total_cost = 0.0f32;
254        let mut total_count = 0u32;
255        for sym in 0..self.max_alphabet_size {
256            let cost = self.bits[ctx * self.max_alphabet_size + sym];
257            if cost < 12.0 {
258                // Only count symbols that exist in the histogram
259                total_cost += cost;
260                total_count += 1;
261            }
262        }
263        if total_count == 0 {
264            return 0.0;
265        }
266        // Higher penalty for contexts with low per-symbol entropy
267        (6.0 - total_cost / total_count as f32).max(0.0)
268    }
269
270    /// Cost of encoding an LZ77 length token using histogram-based estimation.
271    #[allow(dead_code)] // Used for optimal LZ77 mode (future)
272    fn len_cost(&self, ctx: usize, len: u32, lz77: &Lz77Params) -> f32 {
273        // HybridUintConfig(1, 0, 0) for LZ77 length
274        let (tok, nbits) = if len == 0 {
275            (0u32, 0u32)
276        } else {
277            let n = 31 - len.leading_zeros();
278            (1 + n, n)
279        };
280        let sym = tok + lz77.min_symbol;
281        nbits as f32 + self.symbol_cost(ctx, sym as usize)
282    }
283
284    /// Cost of encoding an LZ77 distance token using histogram-based estimation.
285    #[allow(dead_code)] // Used for optimal LZ77 mode (future)
286    fn dist_cost_sce(&self, dist_symbol: u32, lz77: &Lz77Params) -> f32 {
287        let (tok, nbits) = UintCoder::encode(dist_symbol).into();
288        nbits as f32 + self.symbol_cost(lz77.distance_context as usize, tok as usize)
289    }
290}
291
/// Hash chain for LZ77 match finding.
///
/// Uses a sliding window and hash table to efficiently find matching sequences.
/// Matches libjxl's HashChain implementation in enc_lz77.cc.
///
/// Per-position state lives in window slots indexed by `pos & window_mask`, so
/// slots are reused once the stream is longer than the window. Two chains link
/// each slot to an earlier one:
/// - `head` / `chain` / `val`: slots whose next three token values share a hash;
/// - `headz` / `chainz` / `zeros`: slots that begin a zero run of the same
///   length, used to search efficiently inside long runs of zeros.
struct HashChain {
    /// Token values (we only hash on value, not context)
    data: Vec<u32>,
    /// Number of tokens in the stream (`data.len()`)
    size: usize,
    /// Window size (power of 2)
    window_size: usize,
    /// Window mask (window_size - 1)
    window_mask: usize,
    /// Minimum match length reported by `find_matches`
    min_length: usize,
    /// Maximum match length (caps how far a comparison may extend)
    max_length: usize,

    // Hash table parameters
    #[allow(dead_code)] // Stored for debugging/reference
    hash_num_values: usize,
    /// `hash_num_values - 1`; reduces a combined hash to a bucket index
    hash_mask: usize,
    /// Shift applied per value when combining three values into a hash
    hash_shift: u32,

    /// Head of hash chain for each hash value (-1 if empty)
    head: Vec<i32>,
    /// Hash chain: for each window slot, the previous slot inserted under the
    /// same hash; a slot pointing to itself ends the chain
    chain: Vec<u32>,
    /// Hash value at each window position (-1 if invalid)
    val: Vec<i32>,

    // Zero-run optimization
    /// Head of zero-run chain for each run length (lengths 0..=window_size; -1 if empty)
    headz: Vec<i32>,
    /// Zero-run chain: previous slot with the same zero-run length
    /// (a self-reference ends the chain)
    chainz: Vec<u32>,
    /// Number of consecutive zeros starting at each position
    zeros: Vec<u32>,
    /// Zero-run length at the most recently updated position
    numzeros: u32,

    /// Map from actual distance to special distance symbol
    special_dist_table: HashMap<i32, usize>,
    /// Number of special distances (0 if no multiplier, 120 otherwise)
    num_special_distances: usize,

    /// Maximum chain length to traverse (limits search time)
    max_chain_length: u32,
}

impl HashChain {
    /// Creates an empty hash chain over the values of `tokens`.
    ///
    /// `window_size` is expected to be a power of two, since the mask is
    /// computed as `window_size - 1`. When `distance_multiplier` is non-zero,
    /// the special distance codes are enabled and a distance-to-code lookup
    /// table is built; otherwise every distance uses the plain encoding.
    fn new(
        tokens: &[Token],
        window_size: usize,
        min_length: usize,
        max_length: usize,
        distance_multiplier: i32,
    ) -> Self {
        let size = tokens.len();

        // Extract just the values
        let data: Vec<u32> = tokens.iter().map(|t| t.value).collect();

        // Hash table setup
        let hash_num_values = 32768usize;
        let hash_mask = hash_num_values - 1;
        let hash_shift = 5u32;

        let head = vec![-1i32; hash_num_values];
        let chain: Vec<u32> = (0..window_size as u32).collect(); // Self-reference indicates uninitialized
        let val = vec![-1i32; window_size];

        // Zero-run optimization. `headz` has one extra entry because a zero run
        // can span the entire window (length == window_size).
        let headz = vec![-1i32; window_size + 1];
        let chainz: Vec<u32> = (0..window_size as u32).collect();
        let zeros = vec![0u32; window_size];

        // Build special distance table
        let mut special_dist_table = HashMap::new();
        let num_special_distances = if distance_multiplier != 0 {
            // Count down so smallest code wins on ties
            for i in (0..NUM_SPECIAL_DISTANCES).rev() {
                let dist = special_distance(i, distance_multiplier);
                // Codes resolving to a non-positive distance cannot describe a
                // backward reference, so they are never looked up.
                if dist > 0 {
                    special_dist_table.insert(dist, i);
                }
            }
            NUM_SPECIAL_DISTANCES
        } else {
            0
        };

        Self {
            data,
            size,
            window_size,
            window_mask: window_size - 1,
            min_length,
            max_length,
            hash_num_values,
            hash_mask,
            hash_shift,
            head,
            chain,
            val,
            headz,
            chainz,
            zeros,
            numzeros: 0,
            special_dist_table,
            num_special_distances,
            max_chain_length: 256,
        }
    }

    /// Compute hash of 3 consecutive values starting at pos.
    ///
    /// The low 16 bits of each value are shifted by 0, `hash_shift` and
    /// `2 * hash_shift`, XORed together, then masked with `hash_mask`.
    /// Returns 0 when fewer than three values remain from `pos`.
    fn get_hash(&self, pos: usize) -> u32 {
        if pos + 2 >= self.size {
            return 0;
        }
        let mut result = 0u32;
        result ^= self.data[pos] & 0xFFFF;
        result ^= (self.data[pos + 1] & 0xFFFF) << self.hash_shift;
        result ^= (self.data[pos + 2] & 0xFFFF) << (self.hash_shift * 2);
        result & self.hash_mask as u32
    }

    /// Count consecutive zeros starting at pos.
    ///
    /// The count is capped by the window size and the end of the stream. When
    /// `prev_zeros` (the run length at `pos - 1`) is non-zero, the result is
    /// derived from it instead of rescanning. It is normally `prev_zeros - 1`.
    /// It stays at `prev_zeros` when the run already fills the whole window and
    /// still continues at the window's far end.
    fn count_zeros(&self, pos: usize, prev_zeros: u32) -> u32 {
        let end = (pos + self.window_size).min(self.size);
        if prev_zeros > 0 {
            if prev_zeros >= self.window_mask as u32
                && self.data[end - 1] == 0
                && end == pos + self.window_size
            {
                return prev_zeros;
            } else {
                return prev_zeros - 1;
            }
        }
        let mut num = 0u32;
        while pos + (num as usize) < end && self.data[pos + (num as usize)] == 0 {
            num += 1;
        }
        num
    }

    /// Update hash chain with position pos.
    ///
    /// Links `pos` into both the hash chain and the zero-run chain. The
    /// zero-run count is carried over from the previous call, so positions must
    /// be inserted in increasing order, as `apply_lz77_backref` does.
    fn update(&mut self, pos: usize) {
        let hashval = self.get_hash(pos);
        let wpos = pos & self.window_mask;

        self.val[wpos] = hashval as i32;
        // Link to the previous slot with this hash only if one exists. Otherwise
        // the slot keeps whatever link it held before (initially itself).
        if self.head[hashval as usize] != -1 {
            self.chain[wpos] = self.head[hashval as usize] as u32;
        }
        self.head[hashval as usize] = wpos as i32;

        // Update zero count
        // A value change ends any zero run carried over from `pos - 1`.
        if pos > 0 && self.data[pos] != self.data[pos - 1] {
            self.numzeros = 0;
        }
        self.numzeros = self.count_zeros(pos, self.numzeros);

        self.zeros[wpos] = self.numzeros;
        if self.headz[self.numzeros as usize] != -1 {
            self.chainz[wpos] = self.headz[self.numzeros as usize] as u32;
        }
        self.headz[self.numzeros as usize] = wpos as i32;
    }

    /// Update hash chain for multiple positions.
    ///
    /// Inserts `pos..pos + len` in increasing order.
    fn update_range(&mut self, pos: usize, len: usize) {
        for i in 0..len {
            self.update(pos + i);
        }
    }

    /// Find best match at position pos.
    /// Returns (distance_symbol, match_length).
    ///
    /// Picks the longest reported match, breaking ties by the smallest distance
    /// symbol. Returns `(0, 1)` when no candidate of at least `min_length` was
    /// reported.
    fn find_match(&self, pos: usize, max_dist: usize) -> (usize, usize) {
        let mut best_dist_symbol = 0usize;
        let mut best_len = 1usize;

        self.find_matches(pos, max_dist, |len, dist_symbol| {
            if len > best_len || (len == best_len && dist_symbol < best_dist_symbol) {
                best_len = len;
                best_dist_symbol = dist_symbol;
            }
        });

        (best_dist_symbol, best_len)
    }

    /// Find all matches at position pos, calling callback for each.
    ///
    /// Walks earlier window slots, starting from `chain[pos & window_mask]`, and
    /// calls `found_match(len, dist_symbol)` for each candidate that:
    /// - matches at least `min_length` tokens, and
    /// - is at most 2 shorter than the best length reported so far (presumably
    ///   because a slightly shorter match can have a cheaper distance symbol).
    ///
    /// `dist_symbol` is the special code for the distance if one exists, and
    /// `num_special_distances + dist - 1` otherwise. Candidates farther than
    /// `max_dist` are skipped.
    ///
    /// The walk stops when any of these happens:
    /// - `max_chain_length` steps have been taken;
    /// - distances stop increasing (the chain wrapped around the window);
    /// - the chain reaches a self-referencing end;
    /// - the next slot's stored hash (or zero-run length) differs from the one
    ///   being searched.
    ///
    /// `self.numzeros` is read as the zero-run length at `pos`, so `update(pos)`
    /// should be the most recent update when this is called (the caller in this
    /// file ensures that).
    fn find_matches<F>(&self, pos: usize, max_dist: usize, mut found_match: F)
    where
        F: FnMut(usize, usize),
    {
        let wpos = pos & self.window_mask;
        let hashval = self.get_hash(pos);
        let mut hashpos = self.chain[wpos];

        let mut prev_dist = 0i32;
        let end = (pos + self.max_length).min(self.size);
        let mut chain_length = 0u32;
        let mut best_len = 0usize;

        loop {
            // Compute distance from current position to hash chain position.
            // A slot after `wpos` belongs to the previous cycle of the window.
            let dist = if hashpos as usize <= wpos {
                wpos - hashpos as usize
            } else {
                wpos + self.window_mask + 1 - hashpos as usize
            };

            // Distances grow along the chain; a decrease means the walk wrapped
            // around the circular window.
            if (dist as i32) < prev_dist {
                break;
            }
            prev_dist = dist as i32;

            if dist > 0 && dist <= max_dist {
                // Compare sequences
                let mut i = pos;
                let mut j = pos - dist;

                // Zero-run optimization: skip known zeros
                // Both sequences start with at least this many zeros, so those
                // elements already match. At least one element is still compared.
                if self.numzeros > 3 {
                    let r =
                        ((self.numzeros - 1) as usize).min(self.zeros[hashpos as usize] as usize);
                    let skip = if i + r >= end { end - i - 1 } else { r };
                    i += skip;
                    j += skip;
                }

                // Extend match
                while i < end && self.data[i] == self.data[j] {
                    i += 1;
                    j += 1;
                }

                let len = i - pos;

                // Accept match if long enough and potentially better
                if len >= self.min_length && len + 2 >= best_len {
                    let dist_symbol =
                        if let Some(&sym) = self.special_dist_table.get(&(dist as i32)) {
                            sym
                        } else {
                            self.num_special_distances + dist - 1
                        };
                    found_match(len, dist_symbol);
                    if len > best_len {
                        best_len = len;
                    }
                }
            }

            chain_length += 1;
            if chain_length >= self.max_chain_length {
                break;
            }

            // Follow chain
            // Once the best match outgrows the current zero run, restrict the
            // search to slots that begin a zero run of exactly the same length.
            if self.numzeros >= 3 && best_len > self.numzeros as usize {
                // Use zero-run chain for efficiency
                if hashpos == self.chainz[hashpos as usize] {
                    break;
                }
                hashpos = self.chainz[hashpos as usize];
                if self.zeros[hashpos as usize] != self.numzeros {
                    break;
                }
            } else {
                // Use regular hash chain
                if hashpos == self.chain[hashpos as usize] {
                    break;
                }
                hashpos = self.chain[hashpos as usize];
                if self.val[hashpos as usize] != hashval as i32 {
                    // Outdated hash value
                    break;
                }
            }
        }
    }
}
579
580/// Apply greedy LZ77 with backward references using hash chains.
581///
582/// This implements libjxl's `ApplyLZ77_LZ77` algorithm which uses hash chains
583/// to find matching sequences at arbitrary distances within a sliding window.
584/// Includes lazy matching to find longer matches at the next position.
585///
586/// Returns `Some((transformed_tokens, params))` if LZ77 is beneficial,
587/// or `None` if the savings are insufficient.
588pub fn apply_lz77_backref(
589    tokens: &[Token],
590    num_contexts: usize,
591    force_huffman: bool,
592    distance_multiplier: i32,
593) -> Option<(Vec<Token>, Lz77Params)> {
594    if tokens.is_empty() {
595        return None;
596    }
597
598    let mut lz77 = Lz77Params::new(num_contexts, force_huffman);
599
600    // Build cost estimator from original tokens
601    let sce = SymbolCostEstimator::new(num_contexts, force_huffman, tokens, &lz77);
602
603    // Compute cumulative bit costs for original stream
604    let mut sym_cost = vec![0.0f32; tokens.len() + 1];
605    for (i, token) in tokens.iter().enumerate() {
606        let e = UintCoder::encode(token.value);
607        let cost = sce.symbol_cost(token.context as usize, e.token as usize) + e.nbits as f32;
608        sym_cost[i + 1] = sym_cost[i] + cost;
609    }
610
611    let mut out = Vec::with_capacity(tokens.len());
612    let mut bit_decrease: f32 = 0.0;
613    let total_symbols = tokens.len();
614
615    let max_distance = tokens.len();
616    let min_length = lz77.min_length as usize;
617    let max_length = tokens.len();
618
619    // Use next power of two as window size
620    let mut window_size = 1usize;
621    while window_size < max_distance && window_size < WINDOW_SIZE {
622        window_size <<= 1;
623    }
624
625    let mut chain = HashChain::new(
626        tokens,
627        window_size,
628        min_length,
629        max_length,
630        distance_multiplier,
631    );
632
633    const MAX_LAZY_MATCH_LEN: usize = 256;
634    let mut already_updated = false;
635
636    let mut i = 0usize;
637    while i < tokens.len() {
638        out.push(tokens[i]);
639
640        if !already_updated {
641            chain.update(i);
642        }
643        already_updated = false;
644
645        let (mut dist_symbol, mut len) = chain.find_match(i, max_distance);
646
647        if len >= min_length {
648            // Try lazy matching: check if next position has a longer match
649            if len < MAX_LAZY_MATCH_LEN && i + 1 < tokens.len() {
650                chain.update(i + 1);
651                already_updated = true;
652                let (dist_symbol2, len2) = chain.find_match(i + 1, max_distance);
653                if len2 > len {
654                    // Use lazy match: emit literal for current position,
655                    // then use match starting at next position
656                    i += 1;
657                    already_updated = false;
658                    len = len2;
659                    dist_symbol = dist_symbol2;
660                    out.push(tokens[i]);
661                }
662            }
663
664            // Compute costs
665            let literal_cost = sym_cost[i + len] - sym_cost[i];
666            let lz77_len = len - min_length;
667
668            // Use empirical cost tables for LZ77 encoding
669            let lz77_cost = len_cost(lz77_len as u32)
670                + dist_cost(dist_symbol as u32)
671                + sce.add_symbol_cost(out.last().unwrap().context as usize);
672
673            if lz77_cost <= literal_cost {
674                // Emit LZ77 match
675                let last_token = out.last_mut().unwrap();
676                last_token.value = lz77_len as u32;
677                last_token.is_lz77_length = true;
678
679                out.push(Token::new(lz77.distance_context, dist_symbol as u32));
680
681                bit_decrease += literal_cost - lz77_cost;
682            } else {
683                // LZ77 not beneficial, emit literals
684                for j in 1..len {
685                    out.push(tokens[i + j]);
686                }
687            }
688
689            // Update hash chain for matched positions
690            if already_updated {
691                chain.update_range(i + 2, len - 2);
692                already_updated = false;
693            } else {
694                chain.update_range(i + 1, len - 1);
695            }
696            i += len - 1;
697        }
698        // Else: literal already pushed
699
700        i += 1;
701    }
702
703    // Only use LZ77 if savings exceed threshold
704    let threshold = total_symbols as f32 * 0.2 + 16.0;
705    #[cfg(feature = "debug-tokens")]
706    eprintln!(
707        "[LZ77-backref] bit_decrease={:.1}, threshold={:.1}, tokens: {} -> {}",
708        bit_decrease,
709        threshold,
710        total_symbols,
711        out.len()
712    );
713    if bit_decrease > threshold {
714        lz77.enabled = true;
715        Some((out, lz77))
716    } else {
717        None
718    }
719}
720
721/// Apply RLE-based LZ77 compression to a token stream.
722///
723/// Scans for runs of consecutive identical values. When a run is long enough
724/// and the LZ77 encoding is cheaper, replaces the run with:
725/// 1. An LZ77 length token (context = original, value = run_len - min_length, is_lz77_length = true)
726/// 2. A distance token (context = distance_context, value = 0 meaning "repeat previous")
727///
728/// Returns `Some((transformed_tokens, params))` if LZ77 is beneficial,
729/// or `None` if the savings are insufficient.
730pub fn apply_lz77_rle(
731    tokens: &[Token],
732    num_contexts: usize,
733    force_huffman: bool,
734) -> Option<(Vec<Token>, Lz77Params)> {
735    if tokens.is_empty() {
736        return None;
737    }
738
739    let mut lz77 = Lz77Params::new(num_contexts, force_huffman);
740
741    // First pass: build cost estimator from the original tokens (no LZ77 tokens yet).
742    // We pass the original tokens to estimate costs, matching libjxl.
743    let sce = SymbolCostEstimator::new(num_contexts, force_huffman, tokens, &lz77);
744
745    // Compute cumulative bit costs for original stream.
746    let mut sym_cost = vec![0.0f32; tokens.len() + 1];
747    for (i, token) in tokens.iter().enumerate() {
748        let e = UintCoder::encode(token.value);
749        let cost = sce.symbol_cost(token.context as usize, e.token as usize) + e.nbits as f32;
750        sym_cost[i + 1] = sym_cost[i] + cost;
751    }
752
753    let mut out = Vec::with_capacity(tokens.len());
754    let mut bit_decrease: f32 = 0.0;
755    let total_symbols = tokens.len();
756
757    let mut i = 0;
758    while i < tokens.len() {
759        // Count consecutive identical values starting from the PREVIOUS token
760        // (matching libjxl: "if (i > 0) { ... in[i+num_to_copy].value != in[i-1].value }")
761        let mut num_to_copy = 0;
762        if i > 0 {
763            let prev_value = tokens[i - 1].value;
764            while i + num_to_copy < tokens.len() && tokens[i + num_to_copy].value == prev_value {
765                num_to_copy += 1;
766            }
767        }
768
769        if num_to_copy == 0 {
770            out.push(tokens[i]);
771            i += 1;
772            continue;
773        }
774
775        // Cost of encoding the run literally
776        let literal_cost = sym_cost[i + num_to_copy] - sym_cost[i];
777
778        // Cost of LZ77 encoding (rough estimate matching libjxl)
779        let lz77_cost = if num_to_copy >= lz77.min_length as usize {
780            let lz77_len = num_to_copy - lz77.min_length as usize;
781            // CeilLog2Nonzero(lz77_len + 1) + 1 (for distance)
782            ceil_log2_nonzero((lz77_len + 1) as u32) as f32 + 1.0
783        } else {
784            0.0
785        };
786
787        if num_to_copy < lz77.min_length as usize || literal_cost <= lz77_cost {
788            // Not worth encoding as LZ77, emit literal tokens
789            for j in 0..num_to_copy {
790                out.push(tokens[i + j]);
791            }
792            i += num_to_copy;
793            continue;
794        }
795
796        // Emit LZ77 length token
797        let lz77_len = (num_to_copy - lz77.min_length as usize) as u32;
798        out.push(Token::lz77_length(tokens[i].context, lz77_len));
799
800        // Emit distance token (distance_symbol=0 means distance 1 = "repeat previous",
801        // since num_special_distances=0 for RLE)
802        out.push(Token::new(lz77.distance_context, 0));
803
804        bit_decrease += literal_cost - lz77_cost;
805        i += num_to_copy;
806    }
807
808    // Only use LZ77 if savings exceed threshold (matching libjxl)
809    let threshold = total_symbols as f32 * 0.2 + 16.0;
810    #[cfg(feature = "debug-tokens")]
811    eprintln!(
812        "[LZ77] bit_decrease={:.1}, threshold={:.1}, tokens: {} -> {}",
813        bit_decrease,
814        threshold,
815        total_symbols,
816        out.len()
817    );
818    if bit_decrease > threshold {
819        lz77.enabled = true;
820        Some((out, lz77))
821    } else {
822        None
823    }
824}
825
/// Returns `ceil(log2(x))` for a nonzero `x`, mirroring libjxl's `CeilLog2Nonzero`.
///
/// Exact powers of two map to their exponent; every other value rounds up to the
/// next exponent (1 -> 0, 2 -> 1, 3 -> 2, 4 -> 2, 5 -> 3, ...).
///
/// `x` must be nonzero; this precondition is only checked in debug builds.
fn ceil_log2_nonzero(x: u32) -> u32 {
    debug_assert!(x > 0);
    // Position of the highest set bit, i.e. floor(log2(x)).
    let highest_bit = 31 - x.leading_zeros();
    // Anything that is not an exact power of two needs one extra bit.
    let round_up = u32::from(!x.is_power_of_two());
    highest_bit + round_up
}
836
/// Strategy used by `apply_lz77` when searching for LZ77 matches.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum Lz77Method {
    /// Run-length only: a match may only copy the immediately preceding value
    /// (distance 1). Cheap to compute, but finds nothing in streams without runs
    /// of identical values. This is the default.
    #[default]
    Rle,
    /// Hash-chain backward references at arbitrary distances inside the sliding
    /// window. Slower than [`Lz77Method::Rle`]; typically 1-3% smaller output on
    /// photographic content.
    Greedy,
}
849
850/// Apply LZ77 compression using the specified method.
851///
852/// - `Lz77Method::Rle`: RLE-only (fast, limited compression)
853/// - `Lz77Method::Greedy`: Hash chain backward references (slower, better compression)
854///
855/// For photographic content, `Greedy` typically provides 1-3% additional compression
856/// over RLE-only. For graphics/text with runs of identical values, RLE may be sufficient.
857///
858/// Returns `Some((transformed_tokens, params))` if LZ77 is beneficial,
859/// or `None` if the savings are insufficient.
860pub fn apply_lz77(
861    tokens: &[Token],
862    num_contexts: usize,
863    force_huffman: bool,
864    method: Lz77Method,
865    distance_multiplier: i32,
866) -> Option<(Vec<Token>, Lz77Params)> {
867    match method {
868        Lz77Method::Rle => apply_lz77_rle(tokens, num_contexts, force_huffman),
869        Lz77Method::Greedy => {
870            apply_lz77_backref(tokens, num_contexts, force_huffman, distance_multiplier)
871        }
872    }
873}
874
875/// Try both LZ77 methods and return the one with better compression.
876///
877/// This is useful when you want the best compression regardless of speed.
878/// Returns the method that produces fewer output tokens, or None if neither
879/// method provides sufficient savings.
880#[allow(dead_code)] // Utility function for advanced users
881pub fn apply_lz77_best(
882    tokens: &[Token],
883    num_contexts: usize,
884    force_huffman: bool,
885    distance_multiplier: i32,
886) -> Option<(Vec<Token>, Lz77Params)> {
887    let rle_result = apply_lz77_rle(tokens, num_contexts, force_huffman);
888    let backref_result =
889        apply_lz77_backref(tokens, num_contexts, force_huffman, distance_multiplier);
890
891    match (&rle_result, &backref_result) {
892        (Some((rle_tokens, _)), Some((backref_tokens, _))) => {
893            // Return whichever produces fewer tokens
894            if backref_tokens.len() <= rle_tokens.len() {
895                backref_result
896            } else {
897                rle_result
898            }
899        }
900        (Some(_), None) => rle_result,
901        (None, Some(_)) => backref_result,
902        (None, None) => None,
903    }
904}
905
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a token stream in which every token lives in context 0.
    fn ctx0_tokens(values: impl IntoIterator<Item = u32>) -> Vec<Token> {
        values.into_iter().map(|value| Token::new(0, value)).collect()
    }

    #[test]
    fn test_ceil_log2_nonzero() {
        // (input, expected): exact powers of two map to their exponent, everything
        // else rounds up to the next one.
        let cases = [(1, 0), (2, 1), (3, 2), (4, 2), (5, 3), (8, 3), (9, 4)];
        for (x, expected) in cases {
            assert_eq!(ceil_log2_nonzero(x), expected, "ceil_log2_nonzero({x})");
        }
    }

    #[test]
    fn test_no_rle_on_short_stream() {
        // Very short streams shouldn't trigger LZ77.
        let tokens = vec![Token::new(0, 5); 3];
        assert!(apply_lz77_rle(&tokens, 1, false).is_none());
    }

    #[test]
    fn test_rle_on_long_run() {
        // One leading token (the "previous" value the run copies from) followed by
        // a run of 200 identical values.
        let tokens = vec![Token::new(0, 5); 201];

        // `None` is tolerated: whether the run clears the savings threshold depends
        // on the cost estimate for this particular stream.
        if let Some((lz77_tokens, params)) = apply_lz77_rle(&tokens, 1, false) {
            assert!(params.enabled);
            // Should be much shorter than the original.
            assert!(lz77_tokens.len() < tokens.len());
            // Should contain at least one LZ77 length token.
            assert!(lz77_tokens.iter().any(|t| t.is_lz77_length));
        }
    }

    #[test]
    fn test_rle_preserves_non_runs() {
        // Non-repeating prefix, a long run of 42s, then a non-repeating suffix.
        let tokens = ctx0_tokens(
            (0..10)
                .chain(std::iter::repeat(42).take(100))
                .chain(100..110),
        );

        if let Some((lz77_tokens, params)) = apply_lz77_rle(&tokens, 1, false) {
            assert!(params.enabled);
            assert!(lz77_tokens.len() < tokens.len());
            // The first token has no predecessor to copy, so it is emitted literally.
            assert_eq!(lz77_tokens[0].value, 0);
            assert!(!lz77_tokens[0].is_lz77_length);
        }
    }

    #[test]
    fn test_empty_stream() {
        assert!(apply_lz77_rle(&[], 1, false).is_none());
    }

    // Tests for backward-reference LZ77

    #[test]
    fn test_backref_empty_stream() {
        assert!(apply_lz77_backref(&[], 1, false, 0).is_none());
    }

    #[test]
    fn test_backref_short_stream() {
        // Very short streams shouldn't trigger LZ77.
        let tokens = vec![Token::new(0, 5); 3];
        assert!(apply_lz77_backref(&tokens, 1, false, 0).is_none());
    }

    #[test]
    fn test_backref_on_repeating_pattern() {
        // [10, 20, 30] repeated 100 times: repeats at distance 3, which RLE can't use.
        let tokens = ctx0_tokens((0..100).flat_map(|_| [10, 20, 30]));

        if let Some((lz77_tokens, params)) = apply_lz77_backref(&tokens, 1, false, 0) {
            assert!(params.enabled);
            // Should be significantly shorter due to backward references.
            assert!(
                lz77_tokens.len() < tokens.len(),
                "backref should compress pattern: {} vs {}",
                lz77_tokens.len(),
                tokens.len()
            );
            // Should have LZ77 length tokens.
            assert!(lz77_tokens.iter().any(|t| t.is_lz77_length));
        }
    }

    #[test]
    fn test_backref_finds_longer_matches_than_rle() {
        // [1, 2, 3, 4, 5] repeated 50 times: no two neighbours are equal, so RLE has
        // nothing to copy, while backref can match at distance 5.
        let tokens = ctx0_tokens((0..50).flat_map(|_| 1..=5));

        let rle_result = apply_lz77_rle(&tokens, 1, false);
        let backref_result = apply_lz77_backref(&tokens, 1, false, 0);

        match (&rle_result, &backref_result) {
            (None, Some((backref_tokens, _))) => {
                // Expected case: RLE finds nothing, backref does.
                assert!(backref_tokens.len() < tokens.len());
            }
            (Some((rle_tokens, _)), Some((backref_tokens, _))) => {
                // If both activate, backref should do at least as well.
                assert!(backref_tokens.len() <= rle_tokens.len());
            }
            // Neither cleared its savings threshold: acceptable for a small pattern.
            (None, None) => {}
            // RLE only copies from the immediately preceding token, and no token
            // here equals its predecessor, so RLE activating means something broke.
            (Some(_), None) => {
                panic!("RLE activated on a stream with no runs while backref did not")
            }
        }
    }

    #[test]
    fn test_backref_with_distance_multiplier() {
        // 20 identical rows of width 64, so every token repeats the one exactly
        // `image_width` positions earlier (the previous row). With `col % 16` the
        // stream also repeats at distance 16.
        let image_width: i32 = 64;
        let tokens =
            ctx0_tokens((0..20).flat_map(|_| (0..image_width).map(|col| (col % 16) as u32)));

        // Without a multiplier and with `image_width` as multiplier, any accepted
        // result must be enabled and shorter than the input. The main goal is that
        // special distance codes don't crash and produce valid output.
        for multiplier in [0, image_width] {
            if let Some((lz77_tokens, params)) =
                apply_lz77_backref(&tokens, 1, false, multiplier)
            {
                assert!(params.enabled, "multiplier={multiplier}");
                assert!(lz77_tokens.len() < tokens.len(), "multiplier={multiplier}");
            }
        }
    }

    #[test]
    fn test_special_distance() {
        // Each kSpecialDistances entry is [offset, multiplier], and
        // distance = offset + multiplier * distance_multiplier (here 64).
        // [0, 1] -> 0 + 1*64 = 64
        assert_eq!(special_distance(0, 64), 64);
        // [1, 0] -> 1 + 0*64 = 1
        assert_eq!(special_distance(1, 64), 1);
        // [1, 1] -> 1 + 1*64 = 65
        assert_eq!(special_distance(2, 64), 65);
        // [-1, 1] -> -1 + 1*64 = 63
        assert_eq!(special_distance(3, 64), 63);
    }

    #[test]
    fn test_len_cost() {
        // Verify len_cost doesn't panic and stays in a sane range.
        for len in 0..1000 {
            let cost = len_cost(len);
            assert!(cost >= 0.0, "len_cost({}) should be non-negative", len);
            assert!(cost < 100.0, "len_cost({}) should be reasonable", len);
        }
    }

    #[test]
    fn test_dist_cost() {
        // Verify dist_cost doesn't panic and stays in a sane range.
        for dist in 0..10000 {
            let cost = dist_cost(dist);
            assert!(cost >= 0.0, "dist_cost({}) should be non-negative", dist);
            assert!(cost < 100.0, "dist_cost({}) should be reasonable", dist);
        }
    }

    #[test]
    fn test_apply_lz77_method_enum() {
        // One leading token followed by a run of 200 identical values.
        let tokens = vec![Token::new(0, 5); 201];

        for method in [Lz77Method::Rle, Lz77Method::Greedy] {
            if let Some((_, params)) = apply_lz77(&tokens, 1, false, method, 0) {
                assert!(params.enabled, "{method:?} returned disabled params");
            }
        }
    }

    #[test]
    fn test_apply_lz77_best() {
        // [1..=10] repeated 50 times: no runs, but repeats at distance 10, so
        // backref is the method expected to win.
        let tokens = ctx0_tokens((0..50).flat_map(|_| 1..=10));

        if let Some((best_tokens, params)) = apply_lz77_best(&tokens, 1, false, 0) {
            assert!(params.enabled);
            assert!(best_tokens.len() < tokens.len());
        }
    }

    #[test]
    fn test_hash_chain_basic() {
        // Tokens 4..=6 repeat tokens 0..=2; token 3 breaks the sequence.
        let tokens = ctx0_tokens([10, 20, 30, 40, 10, 20, 30]);

        let mut chain = HashChain::new(&tokens, 16, 3, 100, 0);
        // Update chain for all positions.
        for i in 0..tokens.len() {
            chain.update(i);
        }

        // At position 4, should find the match at position 0 (distance 4).
        let (dist_symbol, len) = chain.find_match(4, 10);
        assert!(len >= 3, "should find match of length >= 3, got {}", len);
        // With multiplier = 0 there are no special distances, so
        // dist_symbol = num_special_distances + dist - 1 = 0 + 4 - 1 = 3.
        assert_eq!(dist_symbol, 3, "distance symbol for dist=4 should be 3");
    }
}