Skip to main content

rust_zstd/
compress.rs

1//! Zstandard frame compressor.
2//!
3//! Produces valid zstd frames decompressible by any standard decoder.
4//! Uses greedy hash-based matching (equivalent to zstd level 1).
5
6use super::constants::*;
7
8/// Compress data into a zstd frame.
9///
10/// `level` controls the compression strategy:
11/// - 0: no compression (raw blocks, fastest)
12/// - 1-2: greedy matching (fast, hash_log=14)
13/// - 3-5: lazy matching (better ratio, hash_log=15)
14/// - 6-8: lazy matching + deeper search (hash_log=16)
15/// - 9-11: lazy matching + deepest search (hash_log=17)
16///
17/// Returns a valid zstd frame decompressible by any conformant decoder.
18pub fn compress(data: &[u8], level: i32) -> Vec<u8> {
19    let mut out = Vec::with_capacity(data.len() + 64);
20    write_frame_header(&mut out, data.len() as u64);
21
22    if data.is_empty() {
23        write_raw_block(&mut out, &[], true);
24        return out;
25    }
26
27    if level <= 0 {
28        // Level 0: raw/RLE blocks
29        let blocks: Vec<&[u8]> = data.chunks(ZSTD_BLOCKSIZE_MAX).collect();
30        let n_blocks = blocks.len();
31        for (i, block) in blocks.iter().enumerate() {
32            if is_rle_block(block) {
33                write_rle_block(&mut out, block[0], block.len(), i == n_blocks - 1);
34            } else {
35                write_raw_block(&mut out, block, i == n_blocks - 1);
36            }
37        }
38        return out;
39    }
40
41    // Level 1+: run match finder on entire input (cross-block window), then
42    // split sequences into blocks for encoding.
43    // Try both fast (fewer sequences, better for high-entropy) and lazy (more
44    // matches, better for structured data) and keep the one producing smaller output.
45    let params = MatchParams::from_level(level);
46    let mut all_sequences = find_matches(data, &params);
47
48    // For level 1-2: try both fast and lazy strategies, pick smaller
49    if params.lazy_depth == 0 {
50        let lazy_params = MatchParams {
51            lazy_depth: 1,
52            hash_log: 17,
53            hash_bytes: 5,
54            search_depth: 8,
55        };
56
57        #[cfg(feature = "parallel")]
58        let lazy_seqs = {
59            // Run lazy match finder in parallel with the fast one already computed
60            let lazy_params_clone = lazy_params;
61            rayon::spawn(|| {}); // warm up thread pool
62            let (_, lazy) = rayon::join(
63                || (), // fast already done above
64                || find_matches_lazy(data, &lazy_params_clone),
65            );
66            lazy
67        };
68        #[cfg(not(feature = "parallel"))]
69        let lazy_seqs = find_matches_lazy(data, &lazy_params);
70
71        let fast_enc = resolve_repeat_offsets(&all_sequences);
72        let lazy_enc = resolve_repeat_offsets(&lazy_seqs);
73        let fast_cost = estimate_seq_cost(&all_sequences, &fast_enc);
74        let lazy_cost = estimate_seq_cost(&lazy_seqs, &lazy_enc);
75        if lazy_cost < fast_cost {
76            all_sequences = lazy_seqs;
77        }
78    }
79
80    // Split oversized sequences first, then resolve repeat offsets.
81    let all_sequences = split_long_raw_sequences(all_sequences);
82    let all_encoded = resolve_repeat_offsets(&all_sequences);
83
84    // Split sequences into blocks. Each block's decompressed content must not
85    // exceed ZSTD_BLOCKSIZE_MAX (128 KiB). Decompressed size = sum(ll + ml) for
86    // all sequences in the block + trailing literals (for last block in range).
87    let mut block_ranges: Vec<(usize, usize, usize)> = Vec::new(); // (seq_start, seq_end, data_end)
88    let mut seq_start = 0usize;
89    let mut data_pos = 0usize;
90    let mut block_output = 0usize;
91
92    for (i, raw) in all_sequences.iter().enumerate() {
93        let seq_output = raw.ll as usize + raw.ml as usize;
94        if block_output + seq_output > ZSTD_BLOCKSIZE_MAX && i > seq_start {
95            block_ranges.push((seq_start, i, data_pos));
96            seq_start = i;
97            block_output = 0;
98        }
99        block_output += seq_output;
100        data_pos += seq_output;
101    }
102
103    // Final block: includes remaining sequences + trailing literals.
104    // Ensure decompressed size doesn't exceed BLOCKSIZE_MAX.
105    let trailing_lits = data.len() - data_pos;
106    if block_output + trailing_lits > ZSTD_BLOCKSIZE_MAX && block_output > 0 {
107        // Sequences go in one block (without trailing lits as block data)
108        block_ranges.push((seq_start, all_sequences.len(), data_pos));
109        // Trailing literals become a separate empty-sequences block
110        block_ranges.push((all_sequences.len(), all_sequences.len(), data.len()));
111    } else {
112        block_ranges.push((seq_start, all_sequences.len(), data.len()));
113    }
114
115    let n_blocks = block_ranges.len();
116
117    // Precompute d_start for each block (cumulative sum of ll+ml)
118    let mut d_starts = Vec::with_capacity(n_blocks);
119    for &(s_start, _, _) in &block_ranges {
120        if s_start == 0 {
121            d_starts.push(0);
122        } else {
123            let mut p = 0usize;
124            for s in &all_sequences[..s_start] {
125                p += s.ll as usize + s.ml as usize;
126            }
127            d_starts.push(p);
128        }
129    }
130
131    // Encode each block's content (literals + sequences) — parallelizable
132    #[cfg(feature = "parallel")]
133    let encode_block_content = |bi: usize| -> (Vec<u8>, Vec<u8>) {
134        let (s_start, s_end, d_end) = block_ranges[bi];
135        let block_seqs = &all_encoded[s_start..s_end];
136        let raw_seqs = &all_sequences[s_start..s_end];
137        let d_start = d_starts[bi];
138        let block_data = &data[d_start..d_end];
139
140        if block_seqs.is_empty() {
141            return (vec![], block_data.to_vec()); // raw/rle — return block_data
142        }
143
144        let mut literals = Vec::with_capacity(block_data.len());
145        let mut pos = 0usize;
146        for seq in raw_seqs {
147            literals.extend_from_slice(&block_data[pos..pos + seq.ll as usize]);
148            pos += seq.ll as usize + seq.ml as usize;
149        }
150        literals.extend_from_slice(&block_data[pos..]);
151
152        let mut block = Vec::with_capacity(block_data.len());
153        let mut used_huf = false;
154        if !literals.is_empty() && literals.iter().all(|&b| b == literals[0]) {
155            encode_literals_rle(&mut block, literals[0], literals.len());
156            used_huf = true;
157        } else if literals.len() >= 64 {
158            if let Some(huf) = encode_literals_huffman(&literals) {
159                if huf.len() < literals.len() {
160                    block.extend_from_slice(&huf);
161                    used_huf = true;
162                }
163            }
164        }
165        if !used_huf {
166            encode_literals_raw(&mut block, &literals);
167        }
168        encode_sequences_section(&mut block, block_seqs);
169        (block, block_data.to_vec())
170    };
171
172    // Encode blocks
173    // Parallel: blocks encoded independently (no Treeless)
174    // Sequential: blocks encoded with Treeless reuse from previous block
175    #[cfg(feature = "parallel")]
176    let encoded_blocks: Vec<(Vec<u8>, Vec<u8>)> = {
177        use rayon::prelude::*;
178        (0..n_blocks)
179            .into_par_iter()
180            .map(&encode_block_content)
181            .collect()
182    };
183
184    #[cfg(not(feature = "parallel"))]
185    let encoded_blocks: Vec<(Vec<u8>, Vec<u8>)> = {
186        let mut prev_huf: Option<([(u32, u8); 256], u8)> = None;
187        let mut results = Vec::with_capacity(n_blocks);
188        for bi in 0..n_blocks {
189            let (s_start, s_end, d_end) = block_ranges[bi];
190            let block_seqs = &all_encoded[s_start..s_end];
191            let raw_seqs = &all_sequences[s_start..s_end];
192            let d_start = d_starts[bi];
193            let block_data = &data[d_start..d_end];
194
195            if block_seqs.is_empty() {
196                results.push((vec![], block_data.to_vec()));
197                continue;
198            }
199
200            let mut literals = Vec::with_capacity(block_data.len());
201            let mut pos = 0usize;
202            for seq in raw_seqs {
203                literals.extend_from_slice(&block_data[pos..pos + seq.ll as usize]);
204                pos += seq.ll as usize + seq.ml as usize;
205            }
206            literals.extend_from_slice(&block_data[pos..]);
207
208            let mut block = Vec::with_capacity(block_data.len());
209            let mut used_huf = false;
210
211            if !literals.is_empty() && literals.iter().all(|&b| b == literals[0]) {
212                encode_literals_rle(&mut block, literals[0], literals.len());
213                used_huf = true;
214            } else if literals.len() >= 64 {
215                let new_result = encode_literals_huffman(&literals);
216
217                // Treeless: reuse previous Huffman tree if it covers all symbols
218                let treeless_result = if let Some((prev_codes, _)) = &prev_huf {
219                    if literals.iter().all(|&b| prev_codes[b as usize].1 > 0) {
220                        encode_literals_treeless(&literals, prev_codes)
221                    } else {
222                        None
223                    }
224                } else {
225                    None
226                };
227
228                let mut used_treeless = false;
229                match (&new_result, &treeless_result) {
230                    (Some(_), Some(te))
231                        if te.len() <= new_result.as_ref().unwrap().len()
232                            && te.len() < literals.len() =>
233                    {
234                        block.extend_from_slice(te);
235                        used_huf = true;
236                        used_treeless = true;
237                    }
238                    (Some(ne), _) if ne.len() < literals.len() => {
239                        block.extend_from_slice(ne);
240                        used_huf = true;
241                    }
242                    (None, Some(te)) if te.len() < literals.len() => {
243                        block.extend_from_slice(te);
244                        used_huf = true;
245                        used_treeless = true;
246                    }
247                    _ => {}
248                }
249
250                // Update prev_huf ONLY when a NEW tree was used (Compressed, not Treeless).
251                // When Treeless is used, the decoder keeps the previous tree — so must we.
252                if used_huf && !used_treeless {
253                    let mut counts = [0u32; 256];
254                    let mut ms = 0u8;
255                    for &b in &literals {
256                        counts[b as usize] += 1;
257                        if b > ms {
258                            ms = b;
259                        }
260                    }
261                    if let Some((codes, mb)) = build_huffman_codes(&counts, ms as usize) {
262                        prev_huf = Some((codes, mb));
263                    }
264                }
265            }
266
267            if !used_huf {
268                encode_literals_raw(&mut block, &literals);
269            }
270            encode_sequences_section(&mut block, block_seqs);
271            results.push((block, block_data.to_vec()));
272        }
273        results
274    };
275
276    // Write blocks to output (sequential — must maintain order)
277    for (bi, (block, block_data)) in encoded_blocks.iter().enumerate() {
278        let is_last = bi == n_blocks - 1;
279
280        if block.is_empty() {
281            // No sequences — raw/rle sub-blocks
282            let mut rem = &block_data[..];
283            while !rem.is_empty() {
284                let sz = std::cmp::min(rem.len(), ZSTD_BLOCKSIZE_MAX);
285                let chunk = &rem[..sz];
286                let last = is_last && sz == rem.len();
287                if is_rle_block(chunk) {
288                    write_rle_block(&mut out, chunk[0], chunk.len(), last);
289                } else {
290                    write_raw_block(&mut out, chunk, last);
291                }
292                rem = &rem[sz..];
293            }
294            continue;
295        }
296
297        if block_data.len() <= ZSTD_BLOCKSIZE_MAX {
298            if is_rle_block(block_data) {
299                write_rle_block(&mut out, block_data[0], block_data.len(), is_last);
300            } else if block.len() < block_data.len() && block.len() <= ZSTD_BLOCKSIZE_MAX {
301                write_compressed_block(&mut out, block, is_last);
302            } else {
303                write_raw_block(&mut out, block_data, is_last);
304            }
305        } else if block.len() < block_data.len() && block.len() <= ZSTD_BLOCKSIZE_MAX {
306            write_compressed_block(&mut out, block, is_last);
307        } else {
308            let mut remaining = &block_data[..];
309            while !remaining.is_empty() {
310                let sz = std::cmp::min(remaining.len(), ZSTD_BLOCKSIZE_MAX);
311                let chunk = &remaining[..sz];
312                let last = is_last && sz == remaining.len();
313                if is_rle_block(chunk) {
314                    write_rle_block(&mut out, chunk[0], chunk.len(), last);
315                } else {
316                    write_raw_block(&mut out, chunk, last);
317                }
318                remaining = &remaining[sz..];
319            }
320        }
321    }
322
323    out
324}
325
326/// Convenience wrapper.
327pub fn compress_to_vec(data: &[u8]) -> Vec<u8> {
328    compress(data, 1)
329}
330
331// =========================================================================
332// Frame header
333// =========================================================================
334
335fn write_frame_header(out: &mut Vec<u8>, content_size: u64) {
336    // Magic number (LE)
337    out.extend_from_slice(&ZSTD_MAGIC.to_le_bytes());
338
339    // Frame_Header_Descriptor:
340    // bit 7-6: Frame_Content_Size_flag (determines FCS field size)
341    // bit 5:   Single_Segment_flag (1 = no Window_Descriptor)
342    // bit 4:   unused
343    // bit 3:   reserved
344    // bit 2:   Content_Checksum_flag (0 = no checksum)
345    // bit 1-0: Dictionary_ID_flag (0 = no dict)
346
347    let (fcs_flag, fcs_bytes) = if content_size <= 255 {
348        (0u8, 1) // 1 byte FCS (but flag 0 means 0 bytes normally...)
349    } else if content_size <= 65535 + 256 {
350        (1u8, 2) // 2 bytes
351    } else if content_size <= u32::MAX as u64 {
352        (2u8, 4)
353    } else {
354        (3u8, 8)
355    };
356
357    // For single-segment, FCS_flag=0 means 1 byte (when single_segment=1)
358    let single_segment = 1u8; // always single-segment for simplicity
359    let descriptor = (fcs_flag << 6) | (single_segment << 5);
360    out.push(descriptor);
361
362    // No Window_Descriptor (single_segment = 1)
363
364    // Frame_Content_Size
365    match fcs_bytes {
366        1 => out.push(content_size as u8),
367        2 => out.extend_from_slice(&((content_size - 256) as u16).to_le_bytes()),
368        4 => out.extend_from_slice(&(content_size as u32).to_le_bytes()),
369        8 => out.extend_from_slice(&content_size.to_le_bytes()),
370        _ => {}
371    }
372}
373
374// =========================================================================
375// Block writing
376// =========================================================================
377
378fn write_raw_block(out: &mut Vec<u8>, data: &[u8], is_last: bool) {
379    let header = (is_last as u32) | ((BLOCK_TYPE_RAW as u32) << 1) | ((data.len() as u32) << 3);
380    out.extend_from_slice(&header.to_le_bytes()[..3]);
381    out.extend_from_slice(data);
382}
383
384fn write_rle_block(out: &mut Vec<u8>, byte: u8, repeat_count: usize, is_last: bool) {
385    let header = (is_last as u32) | ((BLOCK_TYPE_RLE as u32) << 1) | ((repeat_count as u32) << 3);
386    out.extend_from_slice(&header.to_le_bytes()[..3]);
387    out.push(byte);
388}
389
/// Report whether `data` is a non-empty run of one repeated byte (RLE candidate).
///
/// An empty slice is never an RLE candidate.
fn is_rle_block(data: &[u8]) -> bool {
    match data.split_first() {
        Some((head, tail)) => tail.iter().all(|b| b == head),
        None => false,
    }
}
398
399fn write_compressed_block(out: &mut Vec<u8>, compressed: &[u8], is_last: bool) {
400    let header =
401        (is_last as u32) | ((BLOCK_TYPE_COMPRESSED as u32) << 1) | ((compressed.len() as u32) << 3);
402    out.extend_from_slice(&header.to_le_bytes()[..3]);
403    out.extend_from_slice(compressed);
404}
405
406// =========================================================================
407// Block compression (greedy matching + raw literals + predefined FSE)
408// =========================================================================
409
/// A sequence as produced by the match finders: `ll` literal bytes followed by
/// a match of `ml` bytes copied from `off` bytes back.
///
/// `off` is always the raw back-reference distance; the repcode-aware
/// "offset value" used on the wire lives in [`EncodedSequence`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct Sequence {
    /// Number of literal bytes preceding the match.
    ll: u32,
    /// Raw back-reference distance.
    off: u32,
    /// Actual match length (>= ZSTD_MINMATCH).
    ml: u32,
}
418
/// A sequence after repeat-offset resolution, ready for entropy coding.
///
/// In zstd, offset values 1/2/3 select a repeat offset; any larger value is a
/// new offset stored as `distance + 3`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct EncodedSequence {
    /// Number of literal bytes preceding the match.
    ll: u32,
    /// Offset value for encoding (1..=3 = repcode, > 3 = new offset + 3).
    of_value: u32,
    /// Match length.
    ml: u32,
}
426
/// Match finder parameters, derived from compression level.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct MatchParams {
    /// log2 of the hash table size.
    hash_log: u32,
    /// Number of bytes fed to the hash function (4, 5, 6, or 7).
    hash_bytes: usize,
    /// 0 = greedy, 1 = lazy, 2 = lazy2.
    lazy_depth: u32,
    /// Hash chain search depth.
    search_depth: u32,
}

impl MatchParams {
    /// Pick the match-finder preset for a compression level.
    ///
    /// | level      | hash_log | hash_bytes | lazy_depth | search_depth |
    /// |------------|----------|------------|------------|--------------|
    /// | 2 or lower | 14       | 7          | 0 (greedy) | 4            |
    /// | 3-5        | 18       | 5          | 1 (lazy)   | 16           |
    /// | 6-8        | 19       | 5          | 1 (lazy)   | 64           |
    /// | 9 or above | 20       | 5          | 1 (lazy)   | 256          |
    ///
    /// The lowest preset follows C zstd level 1 (hashLog=14, 7-byte hash): a
    /// longer hash yields fewer but longer matches, i.e. less per-sequence
    /// overhead, at the price of skipping short matches.
    ///
    /// Levels below zero get the fastest preset rather than falling through to
    /// the deepest-search one.
    fn from_level(level: i32) -> Self {
        match level {
            i32::MIN..=2 => Self {
                hash_log: 14,  // C zstd level 1 uses hashLog=14
                hash_bytes: 7, // 7-byte hash like C zstd level 1 (minMatch=7)
                lazy_depth: 0,
                search_depth: 4,
            },
            3..=5 => Self {
                hash_log: 18,
                hash_bytes: 5,
                lazy_depth: 1,
                search_depth: 16,
            },
            6..=8 => Self {
                hash_log: 19,
                hash_bytes: 5,
                lazy_depth: 1,
                search_depth: 64,
            },
            _ => Self {
                hash_log: 20,
                hash_bytes: 5,
                lazy_depth: 1,
                search_depth: 256,
            },
        }
    }
}
472
473/// Resolve repeat offsets per zstd spec (RFC 8878 §3.1.2.5).
474///
475/// Encoder chooses offset_value for each sequence:
476/// - If raw_offset matches a repeat offset → use repcode (1/2/3)
477/// - Otherwise → use raw_offset + 3
478///
479/// After each sequence, the repeat offset table is updated:
480/// - New offset → shift: rep = [new, old_rep0, old_rep1]
481/// - Repeat 1 → no change
482/// - Repeat 2 → rotate: rep = [rep1, rep0, rep2]
483/// - Repeat 3 → rotate: rep = [rep2, rep0, rep1]
484fn resolve_repeat_offsets(sequences: &[Sequence]) -> Vec<EncodedSequence> {
485    let mut rep = [1u32, 4, 8]; // initial repeat offsets
486    let mut out = Vec::with_capacity(sequences.len());
487
488    for seq in sequences {
489        let raw_off = seq.off;
490        let of_value;
491
492        if seq.ll > 0 {
493            // Normal case (ll > 0)
494            if raw_off == rep[0] {
495                of_value = 1;
496                // rep unchanged
497            } else if raw_off == rep[1] {
498                of_value = 2;
499                // rotate: [rep1, rep0, rep2]
500                rep = [rep[1], rep[0], rep[2]];
501            } else if raw_off == rep[2] {
502                of_value = 3;
503                // rotate: [rep2, rep0, rep1]
504                rep = [rep[2], rep[0], rep[1]];
505            } else {
506                of_value = raw_off + 3;
507                // shift: [new, old0, old1]
508                rep = [raw_off, rep[0], rep[1]];
509            }
510        } else {
511            // ll == 0: offsets are shifted by 1
512            // of_value 1 → rep[1], of_value 2 → rep[2], of_value 3 → rep[0]-1
513            if raw_off == rep[1] {
514                of_value = 1;
515                rep = [rep[1], rep[0], rep[2]];
516            } else if raw_off == rep[2] {
517                of_value = 2;
518                rep = [rep[2], rep[0], rep[1]];
519            } else if raw_off == rep[0].wrapping_sub(1) && rep[0] > 1 {
520                of_value = 3;
521                rep = [rep[0] - 1, rep[0], rep[1]];
522            } else {
523                of_value = raw_off + 3;
524                rep = [raw_off, rep[0], rep[1]];
525            }
526        }
527
528        out.push(EncodedSequence {
529            ll: seq.ll,
530            of_value,
531            ml: seq.ml,
532        });
533    }
534
535    out
536}
537
538/// Split raw sequences where ll+ml > BLOCKSIZE_MAX. Called BEFORE resolve_repeat_offsets.
539fn split_long_raw_sequences(sequences: Vec<Sequence>) -> Vec<Sequence> {
540    let needs_split = sequences
541        .iter()
542        .any(|s| s.ll as usize + s.ml as usize > ZSTD_BLOCKSIZE_MAX);
543    if !needs_split {
544        return sequences;
545    }
546
547    let mut out = Vec::with_capacity(sequences.len() + 16);
548    for seq in sequences {
549        let total = seq.ll as usize + seq.ml as usize;
550        if total <= ZSTD_BLOCKSIZE_MAX {
551            out.push(seq);
552        } else {
553            // First chunk: all literals + partial match
554            let max_ml = ZSTD_BLOCKSIZE_MAX.saturating_sub(seq.ll as usize);
555            let ml_first = std::cmp::max(ZSTD_MINMATCH, std::cmp::min(seq.ml as usize, max_ml));
556            out.push(Sequence {
557                ll: seq.ll,
558                off: seq.off,
559                ml: ml_first as u32,
560            });
561            let mut remaining = seq.ml as usize - ml_first;
562            // Continuation: use ll=1 (not ll=0!) to avoid the ll=0 offset shift
563            // in resolve_repeat_offsets. With ll=1, of_value=1 → rep1 = this offset.
564            // We "borrow" 1 byte from the match to use as a literal.
565            while remaining > ZSTD_MINMATCH {
566                let ml = std::cmp::min(remaining - 1, ZSTD_BLOCKSIZE_MAX - 1);
567                out.push(Sequence {
568                    ll: 1,
569                    off: seq.off,
570                    ml: ml as u32,
571                });
572                remaining -= ml + 1; // 1 literal + ml match
573            }
574        }
575    }
576    out
577}
578
579/// Split oversized sequences (ll+ml > BLOCKSIZE_MAX) after resolve_repeat_offsets.
580/// Continuation sequences get of_value=1 (repeat the same offset that was just used).
581/// For ll=0 continuations, of_value=1 means "rep2" in zstd spec, but since the previous
582/// sequence used the same offset, rep1=offset, so of_value=1 with ll>0 means rep1.
583/// When ll=0, the spec shifts: of_value=1 → rep2. So we need of_value=2 for ll=0
584/// to get rep1 when ll=0... Actually it's simpler: the first continuation has ll=0
585/// and wants the SAME offset. With ll=0, of_value=1 → rep2. But rep2=rep1_before
586/// since the previous seq made rep1=offset. So of_value=1 with ll=0 gives rep2=old_rep1.
587/// That's wrong. We need of_value that gives rep1. With ll=0, of_value for rep1 is...
588/// not directly available. The trick: make ll=1 (one literal byte) for continuation,
589/// then of_value=1 → rep1. But that changes the data.
590///
591/// Simplest correct approach: don't split at the encoded level. Instead, limit match
592/// length during match finding.
593fn find_matches(data: &[u8], params: &MatchParams) -> Vec<Sequence> {
594    if params.lazy_depth == 0 {
595        find_matches_fast(data, params)
596    } else {
597        find_matches_lazy(data, params)
598    }
599}
600
601/// Estimate the encoded cost of a set of sequences (without actually encoding).
602/// Uses: literals_cost (Huffman ~5 bits/byte) + sequence_count * overhead + match_savings.
603fn estimate_seq_cost(raw_seqs: &[Sequence], _enc_seqs: &[EncodedSequence]) -> u64 {
604    let mut literal_bytes = 0u64;
605    for seq in raw_seqs {
606        literal_bytes += seq.ll as u64;
607    }
608    raw_seqs.len() as u64 * 20 + literal_bytes * 5
609}
610
/// Greedy fast match finder modeled on ZSTD_compressBlock_fast_noDict_generic.
/// Single hash table lookup (no chains), rep-code priority, step increment.
///
/// Scans `data` two positions at a time (`ip0` and `ip1 = ip0 + 1`). For each
/// pair it tries, in order:
/// 1. a repeat-offset match at `ip2 = ip0 + step` using the most recent offset,
/// 2. a hash-table match at `ip0`,
/// 3. a hash-table match at `ip1`.
/// The first hit is extended backwards over pending literals, emitted as a
/// `Sequence`, followed by `rep_chain`, and the scan restarts after the match.
///
/// Uses `params.hash_log` (table size) and `params.hash_bytes` (bytes hashed).
///
/// Returns the sequences in input order with raw distances in `off`. Bytes
/// after the last match are not represented by any sequence. Inputs shorter
/// than 8 bytes yield no sequences.
fn find_matches_fast(data: &[u8], params: &MatchParams) -> Vec<Sequence> {
    const MIN_INPUT: usize = 8;
    if data.len() < MIN_INPUT {
        return vec![];
    }

    let hlog = params.hash_log;
    let hash_size = 1usize << hlog;
    let mls = params.hash_bytes; // minimum match length for hash
    // Slots start at 0, so an unused slot looks like "position 0"; every
    // candidate is verified with a 4-byte compare before it is used.
    let mut ht = vec![0u32; hash_size]; // hash table: hash → position
    let mut sequences = Vec::new();

    // Last position from which MIN_INPUT bytes are still available.
    let ilimit = data.len() - MIN_INPUT;
    let mut anchor = 0usize; // start of the literals not yet emitted
    let mut ip0 = 0usize;

    // Local offset history used only to find repeat matches; 0 means "none yet".
    let mut rep1 = 0u32; // most recent offset
    let mut rep2 = 0u32; // second most recent offset

    const K_SEARCH_STRENGTH: u32 = 8;
    const K_STEP_INCR: usize = 1 << (K_SEARCH_STRENGTH - 1); // 128

    'outer: loop {
        let mut step: usize = 2; // initial step between search pairs
        let mut next_step = ip0 + K_STEP_INCR;

        let mut ip1 = ip0 + 1;

        if ip1 > ilimit {
            break;
        }

        // Pre-hash ip0
        let mut h0 = hash_n(&data[ip0..], (hash_size - 1) as u32, mls);

        loop {
            let ip2 = ip0 + step;
            let ip3 = ip1 + step;
            if ip3 > ilimit {
                break 'outer;
            }

            // Get match candidate from hash table for ip0
            // (must be read before the slot is overwritten with ip0 below)
            let match_idx0 = ht[h0] as usize;

            // Hash ip1, write ip0 to hash table
            let h1 = hash_n(&data[ip1..], (hash_size - 1) as u32, mls);
            ht[h0] = ip0 as u32;

            // === Rep-code check at ip2 (before hash match) ===
            if rep1 > 0 && ip2 >= rep1 as usize {
                let rep_cand = ip2 - rep1 as usize;
                if rep_cand + 4 <= data.len()
                    && ip2 + 4 <= data.len()
                    && read32(data, ip2) == read32(data, rep_cand)
                {
                    // Rep match at ip2! Write hash for ip1 first
                    ht[h1] = ip1 as u32;

                    let mut mlen = 4 + count_match(data, ip2 + 4, rep_cand + 4);
                    // Backward extension
                    // NOTE(review): only this branch caps the extension with
                    // MAX_MATCH_LEN; the two hash-match branches below do not.
                    let mut start = ip2;
                    let mut mstart = rep_cand;
                    while start > anchor
                        && mstart > 0
                        && mlen < MAX_MATCH_LEN
                        && data[start - 1] == data[mstart - 1]
                    {
                        start -= 1;
                        mstart -= 1;
                        mlen += 1;
                    }

                    // The offset is rep1 itself, so the local history stays as is.
                    let ll = (start - anchor) as u32;
                    sequences.push(Sequence {
                        ll,
                        off: rep1,
                        ml: mlen as u32,
                    });
                    ip0 = start + mlen;
                    anchor = ip0;

                    // Rep-code chaining: check rep2 at new position
                    rep_chain(
                        data,
                        &mut ip0,
                        &mut anchor,
                        &mut sequences,
                        &mut rep1,
                        &mut rep2,
                        &mut ht,
                        hlog,
                        mls,
                        ilimit,
                    );
                    continue 'outer;
                }
            }

            // === Hash match check at ip0 ===
            // Candidate must lie before ip0, within a 16 MiB distance, and
            // agree on the first 4 bytes.
            if match_idx0 < ip0
                && ip0 - match_idx0 <= (1 << 24)
                && match_idx0 + 4 <= data.len()
                && read32(data, ip0) == read32(data, match_idx0)
            {
                ht[h1] = ip1 as u32;
                let match0 = match_idx0;

                let mut mlen = 4 + count_match(data, ip0 + 4, match0 + 4);
                let mut start = ip0;
                let mut mstart = match0;
                // Backward extension over pending literals
                while start > anchor && mstart > 0 && data[start - 1] == data[mstart - 1] {
                    start -= 1;
                    mstart -= 1;
                    mlen += 1;
                }

                // New offset: push it onto the local history
                let offset = (start - mstart) as u32;
                rep2 = rep1;
                rep1 = offset;

                let ll = (start - anchor) as u32;
                sequences.push(Sequence {
                    ll,
                    off: offset,
                    ml: mlen as u32,
                });
                ip0 = start + mlen;
                anchor = ip0;

                // Fill hash table for end-of-match positions
                if ip0 > 2 && ip0 + MIN_INPUT <= data.len() {
                    ht[hash_n(&data[ip0 - 2..], (hash_size - 1) as u32, mls)] = (ip0 - 2) as u32;
                }

                rep_chain(
                    data,
                    &mut ip0,
                    &mut anchor,
                    &mut sequences,
                    &mut rep1,
                    &mut rep2,
                    &mut ht,
                    hlog,
                    mls,
                    ilimit,
                );
                continue 'outer;
            }

            // === No match at ip0 — check ip1 ===
            // Read the candidate for ip1 before storing ip1 in its slot, and
            // pre-hash ip2, which becomes ip0 if nothing matches here either.
            let match_idx1 = ht[h1] as usize;
            h0 = hash_n(&data[ip2..], (hash_size - 1) as u32, mls);
            ht[h1] = ip1 as u32;

            if match_idx1 < ip1
                && ip1 - match_idx1 <= (1 << 24)
                && match_idx1 + 4 <= data.len()
                && read32(data, ip1) == read32(data, match_idx1)
            {
                let match0 = match_idx1;
                let mut mlen = 4 + count_match(data, ip1 + 4, match0 + 4);
                let mut start = ip1;
                let mut mstart = match0;
                // Backward extension over pending literals
                while start > anchor && mstart > 0 && data[start - 1] == data[mstart - 1] {
                    start -= 1;
                    mstart -= 1;
                    mlen += 1;
                }

                // New offset: push it onto the local history
                let offset = (start - mstart) as u32;
                rep2 = rep1;
                rep1 = offset;

                let ll = (start - anchor) as u32;
                sequences.push(Sequence {
                    ll,
                    off: offset,
                    ml: mlen as u32,
                });
                ip0 = start + mlen;
                anchor = ip0;

                // Fill hash table for end-of-match positions
                if ip0 > 2 && ip0 + MIN_INPUT <= data.len() {
                    ht[hash_n(&data[ip0 - 2..], (hash_size - 1) as u32, mls)] = (ip0 - 2) as u32;
                }

                rep_chain(
                    data,
                    &mut ip0,
                    &mut anchor,
                    &mut sequences,
                    &mut rep1,
                    &mut rep2,
                    &mut ht,
                    hlog,
                    mls,
                    ilimit,
                );
                continue 'outer;
            }

            // No match at ip0 or ip1 — advance with step
            ip0 = ip2;
            ip1 = ip3;

            // Step increment: search less densely in non-matching regions
            // (one extra byte of stride per K_STEP_INCR bytes without a match)
            if ip0 >= next_step {
                step += 1;
                next_step += K_STEP_INCR;
            }
        }
    }

    sequences
}
829
830/// Rep-code chaining: after a match, check if the next position matches rep2.
831#[allow(clippy::too_many_arguments)]
832#[inline]
833fn rep_chain(
834    data: &[u8],
835    ip: &mut usize,
836    anchor: &mut usize,
837    sequences: &mut Vec<Sequence>,
838    rep1: &mut u32,
839    rep2: &mut u32,
840    ht: &mut [u32],
841    hlog: u32,
842    mls: usize,
843    _ilimit: usize,
844) {
845    let hash_size = 1usize << hlog;
846    while *rep2 > 0 && *ip + 4 <= data.len() && *ip >= *rep2 as usize {
847        let cand = *ip - *rep2 as usize;
848        if cand + 4 > data.len() || read32(data, *ip) != read32(data, cand) {
849            break;
850        }
851        let mlen = 4 + count_match(data, *ip + 4, cand + 4);
852        // Swap rep codes
853        std::mem::swap(rep1, rep2);
854
855        // Update hash table
856        if *ip + 8 <= data.len() {
857            ht[hash_n(&data[*ip..], (hash_size - 1) as u32, mls)] = *ip as u32;
858        }
859
860        sequences.push(Sequence {
861            ll: 0,
862            off: *rep1,
863            ml: mlen as u32,
864        });
865        *ip += mlen;
866        *anchor = *ip;
867    }
868}
869
/// Read the four bytes starting at `data[pos]` as a little-endian `u32`.
///
/// Panics if fewer than four bytes are available at `pos`.
#[inline]
fn read32(data: &[u8], pos: usize) -> u32 {
    let mut word = [0u8; 4];
    word.copy_from_slice(&data[pos..pos + 4]);
    u32::from_le_bytes(word)
}
874
/// Upper bound on the length of a single match, in bytes.
///
/// Set to `ZSTD_BLOCKSIZE_MAX`. `count_match` uses it to stop extending a match:
/// it allows at most `MAX_MATCH_LEN - 4` additional bytes, and its callers in
/// this file add the 4 bytes they have already verified.
/// NOTE(review): this bounds `ml` on its own, not `ll + ml`; the original note
/// accepts that because `ll` is typically small — confirm that block splitting
/// copes with a sequence whose `ll + ml` exceeds one block.
const MAX_MATCH_LEN: usize = ZSTD_BLOCKSIZE_MAX;
878
879#[inline]
880fn count_match(data: &[u8], mut a: usize, mut b: usize) -> usize {
881    let start = a;
882    let max_extend = MAX_MATCH_LEN - 4;
883    let limit = std::cmp::min(data.len(), start + max_extend);
884    while a + 8 <= limit && b + 8 <= data.len() {
885        let va = u64::from_le_bytes(data[a..a + 8].try_into().unwrap());
886        let vb = u64::from_le_bytes(data[b..b + 8].try_into().unwrap());
887        if va != vb {
888            return a - start + (va ^ vb).trailing_zeros() as usize / 8;
889        }
890        a += 8;
891        b += 8;
892    }
893    while a < limit && b < data.len() && data[a] == data[b] {
894        a += 1;
895        b += 1;
896    }
897    a - start
898}
899
/// Dual-hash match finder with rep-code priority and optional lazy evaluation.
///
/// Scans `data` left to right and returns the sequences it found, each with a
/// literal run length (`ll`), a match offset (`off`) and a match length (`ml`).
/// Inputs shorter than 8 bytes yield an empty list. Two tables are consulted
/// at every position:
/// - `ht_short` + `chain`: 4-byte hash whose chain is walked for up to
///   `params.search_depth` candidates
/// - `ht_long`: 7-byte hash holding a single candidate per bucket
///
/// Selection rules, as implemented below:
/// - A match at the most recent offset (`rep1`) is checked first and kept only
///   when it is at least 9 bytes long
/// - Hash matches must pass `is_match_profitable`
/// - After each emitted match, `rep2` is tried repeatedly at the new position
///   (rep-code chaining)
/// - When `params.lazy_depth >= 1`, position `ip + 1` is probed as well and
///   preferred when its match is more than one byte longer
fn find_matches_lazy(data: &[u8], params: &MatchParams) -> Vec<Sequence> {
    if data.len() < 8 {
        return vec![];
    }

    let hash_size = 1usize << params.hash_log;
    let hash_mask = (hash_size - 1) as u32;
    // The long table never exceeds 2^17 entries, whatever hash_log is.
    let long_hash_size = 1usize << std::cmp::min(params.hash_log, 17);
    let long_hash_mask = (long_hash_size - 1) as u32;
    let mut ht_short = vec![0u32; hash_size]; // 4-byte hash → position
    let mut ht_long = vec![0u32; long_hash_size]; // 7-byte hash → position
    // One link per input position: previous position in the same short bucket.
    let mut chain = vec![0u32; data.len()];
    let mut sequences = Vec::new();
    let mut anchor = 0usize; // start of the pending literal run
    let mut ip = 0usize; // current scan position
    let mut rep1 = 0u32; // most recent offset (0 = none yet)
    let mut rep2 = 0u32; // second most recent offset (0 = none yet)
    let lazy = params.lazy_depth >= 1;

    while ip + 8 <= data.len() {
        // === 1. Rep-code check ===
        // A rep match spends no offset bits but still pays the LL/ML FSE
        // overhead (~12 bits by the original estimate), while each matched byte
        // saves roughly 5 bits of literals. Short rep matches are skipped so
        // that high-entropy data does not produce many marginal sequences.
        let rep_match = if rep1 > 0 && ip >= rep1 as usize && ip + 4 <= data.len() {
            let cand = ip - rep1 as usize;
            if cand + 4 <= data.len() && read32(data, ip) == read32(data, cand) {
                let ml = 4 + count_match(data, ip + 4, cand + 4);
                // `ml * 5 > 40` is `ml >= 9`: only rep matches of nine or more
                // bytes are kept here.
                if ml * 5 > 40 {
                    Some((rep1 as usize, ml))
                } else {
                    None
                }
            } else {
                None
            }
        } else {
            None
        };

        // === 2. Hash-chain match on the short table ===
        // The hash width is params.hash_bytes capped at 4; every insertion
        // into ht_short below uses 4.
        let short_match = find_best_at_n(
            data,
            ip,
            &ht_short,
            &chain,
            hash_mask,
            params.search_depth,
            std::cmp::min(params.hash_bytes, 4),
        );

        // === 3. Long hash match (7-byte, single lookup, for long-range) ===
        // The bucket is overwritten with `ip` before the old entry is examined.
        let long_match = if ip + 7 <= data.len() {
            let lh = hash7(&data[ip..], long_hash_mask);
            let lidx = ht_long[lh] as usize;
            ht_long[lh] = ip as u32;
            if lidx < ip
                && ip - lidx <= (1 << 24)
                && lidx + 4 <= data.len()
                && read32(data, ip) == read32(data, lidx)
            {
                let ml = 4 + count_match(data, ip + 4, lidx + 4);
                Some((ip - lidx, ml))
            } else {
                None
            }
        } else {
            None
        };

        // === 4. Pick best overall ===
        // Hash matches are filtered for profitability; the rep match is not
        // passed through this filter.
        let short_match = short_match.filter(|&(off, ml)| is_match_profitable(ml, off));
        let long_match = long_match.filter(|&(off, ml)| is_match_profitable(ml, off));

        // The long-table match wins only when it is at least 2 bytes longer.
        let best_hash = match (short_match, long_match) {
            (Some((so, sl)), Some((lo, ll))) => {
                if ll >= sl + 2 {
                    Some((lo, ll))
                } else {
                    Some((so, sl))
                }
            }
            (Some(s), None) => Some(s),
            (None, Some(l)) => Some(l),
            (None, None) => None,
        };

        // The rep match wins unless the hash match is longer by more than
        // off_bits / 4 bytes (presumably a credit for the offset bits that a
        // rep match does not spend).
        let chosen = match (rep_match, best_hash) {
            (Some((roff, rml)), Some((hoff, hml))) => {
                let off_bits = 32u32.saturating_sub((hoff as u32).leading_zeros());
                if rml + (off_bits as usize / 4) >= hml {
                    Some((roff, rml))
                } else {
                    Some((hoff, hml))
                }
            }
            (Some(r), None) => Some(r),
            (None, Some(h)) => Some(h),
            (None, None) => None,
        };

        if let Some((offset, match_len)) = chosen {
            let mut final_off = offset;
            let mut final_len = match_len;
            let mut final_ip = ip;

            // Lazy: check ip+1 for better match
            if lazy && ip + 1 + 8 <= data.len() {
                // NOTE(review): `ip` is inserted here and again by the
                // `ip..end` loop after the sequence is emitted — confirm that
                // insert_hash_n tolerates the repeated insertion.
                insert_hash_n(&mut ht_short, &mut chain, data, ip, hash_mask, 4);
                let mut next_best = None;
                // Rep at ip+1
                if rep1 > 0 && ip + 1 >= rep1 as usize && ip + 5 <= data.len() {
                    let c = ip + 1 - rep1 as usize;
                    if c + 4 <= data.len() && read32(data, ip + 1) == read32(data, c) {
                        let rl = 4 + count_match(data, ip + 5, c + 4);
                        if rl > final_len {
                            next_best = Some((rep1 as usize, rl));
                        }
                    }
                }
                // Hash chain at ip+1 (not filtered by is_match_profitable)
                if let Some((off2, len2)) = find_best_at_n(
                    data,
                    ip + 1,
                    &ht_short,
                    &chain,
                    hash_mask,
                    params.search_depth,
                    4,
                ) {
                    if len2 > final_len + 1 && (next_best.is_none() || len2 > next_best.unwrap().1)
                    {
                        next_best = Some((off2, len2));
                    }
                }
                // Switch to ip+1 only when it gains more than the one literal
                // byte that skipping `ip` costs.
                if let Some((off2, len2)) = next_best {
                    if len2 > final_len + 1 {
                        final_off = off2;
                        final_len = len2;
                        final_ip = ip + 1;
                    }
                }
            }

            let ll = (final_ip - anchor) as u32;
            sequences.push(Sequence {
                ll,
                off: final_off as u32,
                ml: final_len as u32,
            });

            // The repeat-offset history shifts only when the offset changed.
            if final_off as u32 != rep1 {
                rep2 = rep1;
                rep1 = final_off as u32;
            }

            // Index every position from `ip` to the end of the match in both
            // tables, stopping 4 bytes before the end of the input.
            let end = std::cmp::min(final_ip + final_len, data.len().saturating_sub(4));
            for p in ip..end {
                insert_hash_n(&mut ht_short, &mut chain, data, p, hash_mask, 4);
                if p + 7 <= data.len() {
                    ht_long[hash7(&data[p..], long_hash_mask)] = p as u32;
                }
            }

            ip = final_ip + final_len;
            anchor = ip;

            // === Rep-code chaining: check rep2 ===
            // Each hit emits a zero-literal sequence and swaps rep1/rep2.
            while rep2 > 0 && ip + 4 <= data.len() && ip >= rep2 as usize {
                let cand = ip - rep2 as usize;
                if cand + 4 > data.len() || read32(data, ip) != read32(data, cand) {
                    break;
                }
                let mlen = 4 + count_match(data, ip + 4, cand + 4);
                // mlen is at least 4 here, so this guard can only fire if
                // ZSTD_MINMATCH is larger than 4 — presumably defensive.
                if mlen < ZSTD_MINMATCH {
                    break;
                }

                std::mem::swap(&mut rep2, &mut rep1);
                sequences.push(Sequence {
                    ll: 0,
                    off: rep1,
                    ml: mlen as u32,
                });

                let end2 = std::cmp::min(ip + mlen, data.len().saturating_sub(4));
                for p in ip..end2 {
                    insert_hash_n(&mut ht_short, &mut chain, data, p, hash_mask, 4);
                    if p + 7 <= data.len() {
                        ht_long[hash7(&data[p..], long_hash_mask)] = p as u32;
                    }
                }
                ip += mlen;
                anchor = ip;
            }
        } else {
            // No usable match: index this position and move on by one byte.
            insert_hash_n(&mut ht_short, &mut chain, data, ip, hash_mask, 4);
            if ip + 7 <= data.len() {
                ht_long[hash7(&data[ip..], long_hash_mask)] = ip as u32;
            }
            ip += 1;
        }
    }

    sequences
}
1121
/// Estimate whether encoding `match_len` bytes as a match at `offset` beats
/// leaving them as literals.
///
/// Cost model (all figures are rough bit counts):
/// - every sequence pays about 17 fixed bits (6 for LL, 5 for the OF state,
///   6 for ML) plus one extra bit per significant bit of `offset`
/// - every matched byte saves about 5 bits of Huffman-coded literals
///
/// The match is accepted only when the savings exceed the cost by more than
/// 8 bits. Examples: 6 bytes at offset 8 pass (30 vs 21 + 8); 4 bytes at
/// offset 1000 do not (20 vs 27 + 8).
#[inline]
fn is_match_profitable(match_len: usize, offset: usize) -> bool {
    const FIXED_SEQUENCE_BITS: u32 = 17;
    const BITS_SAVED_PER_BYTE: u32 = 5;
    const REQUIRED_MARGIN_BITS: u32 = 8;

    // Bit length of the offset; offsets 0 and 1 are charged a single bit.
    let offset_bits = match offset {
        0 | 1 => 1,
        _ => 32 - (offset as u32).leading_zeros(),
    };

    let cost = FIXED_SEQUENCE_BITS + offset_bits;
    let gain = match_len as u32 * BITS_SAVED_PER_BYTE;
    gain > cost + REQUIRED_MARGIN_BITS
}
1146
1147/// After a match ends, immediately check for a rep-code match at the current
1148/// position using rep[1] (the second repeat offset). This chains consecutive
1149/// matches without re-entering the main loop, matching C zstd behavior.
1150fn insert_hash_n(
1151    hash_table: &mut [u32],
1152    chain: &mut [u32],
1153    data: &[u8],
1154    pos: usize,
1155    mask: u32,
1156    hash_bytes: usize,
1157) {
1158    if pos + hash_bytes > data.len() {
1159        return;
1160    }
1161    let h = hash_n(&data[pos..], mask, hash_bytes);
1162    chain[pos] = hash_table[h];
1163    hash_table[h] = pos as u32;
1164}
1165
1166/// Insert position into hash chain (4-byte hash, used by lazy lookups).
1167#[inline]
1168fn hash_n(data: &[u8], mask: u32, n: usize) -> usize {
1169    match n {
1170        5 => hash5(data, mask),
1171        6 => hash6(data, mask),
1172        7 => hash7(data, mask),
1173        _ => hash4(data, mask),
1174    }
1175}
1176
/// 6-byte multiplicative hash (prime 227718039650203, as in C zstd's hash6).
///
/// The six input bytes are left-aligned in a 64-bit word, multiplied by the
/// prime, and the top bits of the product are used as the index. Taking the
/// top bits matters: the low bits of a product depend only on the low bits of
/// the input, so a fixed `>> 24` followed by a mask lets the sixth byte have
/// almost no influence on the bucket.
///
/// `mask` is expected to be `2^k - 1`; the index uses `k` bits. The final
/// `& mask` keeps the result in range for any mask. Panics if `data` has
/// fewer than 6 bytes.
#[inline]
fn hash6(data: &[u8], mask: u32) -> usize {
    let v = u64::from_le_bytes([data[0], data[1], data[2], data[3], data[4], data[5], 0, 0]);
    let bits = 32 - mask.leading_zeros();
    let product = (v << 16).wrapping_mul(227718039650203u64);
    // checked_shr is None only for bits == 0 (shift by 64): a single bucket.
    product.checked_shr(64 - bits).unwrap_or(0) as usize & mask as usize
}
1183
/// 7-byte hash following C zstd's ZSTD_hash7 (prime=58295818150454627).
///
/// The seven input bytes are left-aligned in a 64-bit word, multiplied by the
/// prime, and the top bits of the product are used as the index. This is what
/// makes the hash useful for f64 data whose first 4-6 bytes are often
/// identical (0x00): the distinguishing bytes are the last ones, and only the
/// top bits of the product depend on them. A fixed `>> 24` followed by a mask
/// would ignore bytes 5 and 6 almost entirely.
///
/// `mask` is expected to be `2^k - 1`; the index uses `k` bits. The final
/// `& mask` keeps the result in range for any mask. Panics if `data` has
/// fewer than 7 bytes.
#[inline]
fn hash7(data: &[u8], mask: u32) -> usize {
    let v = u64::from_le_bytes([
        data[0], data[1], data[2], data[3], data[4], data[5], data[6], 0,
    ]);
    let bits = 32 - mask.leading_zeros();
    let product = (v << 8).wrapping_mul(58295818150454627u64);
    // checked_shr is None only for bits == 0 (shift by 64): a single bucket.
    product.checked_shr(64 - bits).unwrap_or(0) as usize & mask as usize
}
1193
1194/// Find the best match at `pos` by walking the hash chain.
1195fn find_best_at_n(
1196    data: &[u8],
1197    pos: usize,
1198    hash_table: &[u32],
1199    chain: &[u32],
1200    mask: u32,
1201    max_depth: u32,
1202    hash_bytes: usize,
1203) -> Option<(usize, usize)> {
1204    if pos + hash_bytes > data.len() {
1205        return None;
1206    }
1207    let h = hash_n(&data[pos..], mask, hash_bytes);
1208    let mut candidate = hash_table[h] as usize;
1209    let mut best_len = ZSTD_MINMATCH - 1;
1210    let mut best_off = 0;
1211
1212    for _ in 0..max_depth {
1213        if candidate >= pos || pos - candidate > (1 << 24) {
1214            break;
1215        }
1216        if candidate + ZSTD_MINMATCH > data.len() {
1217            break;
1218        }
1219
1220        // Quick 4-byte check
1221        if data[candidate..candidate + 4] == data[pos..pos + 4] {
1222            // Cap match length at spec maximum (ML code 52: 65539 + 65535 = 131074)
1223            let max_ml = std::cmp::min(ZSTD_BLOCKSIZE_MAX, data.len() - pos);
1224            let cand_max = std::cmp::min(max_ml, data.len() - candidate);
1225            let ml = common_prefix_len(
1226                &data[candidate..candidate + cand_max],
1227                &data[pos..pos + cand_max],
1228            );
1229            if ml > best_len {
1230                best_len = ml;
1231                best_off = pos - candidate;
1232            }
1233        }
1234
1235        let next = chain[candidate] as usize;
1236        if next >= candidate {
1237            break;
1238        }
1239        candidate = next;
1240    }
1241
1242    if best_len >= ZSTD_MINMATCH {
1243        Some((best_off, best_len))
1244    } else {
1245        None
1246    }
1247}
1248
/// Length of the longest common prefix of `a` and `b`.
///
/// Compares eight bytes at a time as little-endian words; on the first pair of
/// words that differ, the trailing-zero count of their XOR locates the first
/// differing byte. The leftover tail (fewer than eight bytes) is compared
/// byte by byte.
#[inline]
fn common_prefix_len(a: &[u8], b: &[u8]) -> usize {
    let shared = a.len().min(b.len());
    let (lhs, rhs) = (&a[..shared], &b[..shared]);

    let mut matched = 0;
    for (wa, wb) in lhs.chunks_exact(8).zip(rhs.chunks_exact(8)) {
        let x = u64::from_le_bytes(wa.try_into().unwrap());
        let y = u64::from_le_bytes(wb.try_into().unwrap());
        let diff = x ^ y;
        if diff != 0 {
            return matched + (diff.trailing_zeros() / 8) as usize;
        }
        matched += 8;
    }

    matched
        + lhs[matched..]
            .iter()
            .zip(&rhs[matched..])
            .take_while(|(p, q)| p == q)
            .count()
}
1267
/// 5-byte multiplicative hash for better collision avoidance.
/// Follows C zstd's hash5 (prime 889523592379).
///
/// The five input bytes are left-aligned in a 64-bit word, multiplied by the
/// prime, and the top bits of the product are used as the index, so every
/// input bit can influence the bucket for any table size. (A fixed `>> 24`
/// followed by a mask drops the top bits of the fifth byte whenever the table
/// has fewer than 16 index bits.)
///
/// `mask` is expected to be `2^k - 1`; the index uses `k` bits. The final
/// `& mask` keeps the result in range for any mask. Panics if `data` has
/// fewer than 5 bytes.
#[inline]
fn hash5(data: &[u8], mask: u32) -> usize {
    let v = u64::from_le_bytes([data[0], data[1], data[2], data[3], data[4], 0, 0, 0]);
    let bits = 32 - mask.leading_zeros();
    let product = (v << 24).wrapping_mul(889523592379u64);
    // checked_shr is None only for bits == 0 (shift by 64): a single bucket.
    product.checked_shr(64 - bits).unwrap_or(0) as usize & mask as usize
}
1275
/// 4-byte multiplicative hash, reduced to a table index.
///
/// The four input bytes are read as a little-endian `u32`, multiplied by
/// 0x9E3779B1, and the top bits of the 32-bit product are used as the index
/// (the same scheme as C zstd's 4-byte hash). Masking the low bits of the
/// product instead would make the bucket depend only on the low bits of the
/// input — about the first two bytes for a 14-17 bit table — so positions
/// sharing a two-byte prefix would all collide.
///
/// `mask` is expected to be `2^k - 1`; the index uses `k` bits. The final
/// `& mask` keeps the result in range for any mask. Panics if `data` has
/// fewer than 4 bytes.
#[inline]
fn hash4(data: &[u8], mask: u32) -> usize {
    let v = u32::from_le_bytes([data[0], data[1], data[2], data[3]]);
    let bits = 32 - mask.leading_zeros();
    let product = v.wrapping_mul(0x9E3779B1);
    // checked_shr is None only for bits == 0 (shift by 32): a single bucket.
    (product.checked_shr(32 - bits).unwrap_or(0) as usize) & (mask as usize)
}
1282
1283// =========================================================================
1284// Huffman literal compression
1285// =========================================================================
1286
1287/// Build codes for Treeless reuse. Roundtrip-verifies through encode_huffman_tree
1288/// + decode to ensure encoder codes exactly match what decoder reconstructs.
1289/// Encode literals using a previous Huffman tree (Treeless_Literals_Block, type=3).
1290#[cfg(not(feature = "parallel"))]
1291fn encode_literals_treeless(literals: &[u8], prev_codes: &[(u32, u8); 256]) -> Option<Vec<u8>> {
1292    let use_4 = literals.len() >= 1024;
1293    let streams = if use_4 {
1294        encode_huf_4streams(literals, prev_codes)
1295    } else {
1296        encode_huf_1stream(literals, prev_codes)
1297    };
1298    let regen = literals.len();
1299    let comp = streams.len();
1300    if comp >= regen {
1301        return None;
1302    }
1303    let lh_size = 3 + (regen >= 1024) as usize + (regen >= 16384) as usize;
1304    let mut out = Vec::with_capacity(lh_size + comp);
1305    let htype = LIT_TYPE_TREELESS as u32;
1306    match lh_size {
1307        3 => {
1308            let sf = if use_4 { 1u32 } else { 0 };
1309            out.extend_from_slice(
1310                &(htype | (sf << 2) | ((regen as u32) << 4) | ((comp as u32) << 14)).to_le_bytes()
1311                    [..3],
1312            );
1313        }
1314        4 => out.extend_from_slice(
1315            &(htype | (2u32 << 2) | ((regen as u32) << 4) | ((comp as u32) << 18)).to_le_bytes()
1316                [..4],
1317        ),
1318        _ => {
1319            let v = htype | (3u32 << 2) | ((regen as u32) << 4) | ((comp as u32) << 22);
1320            out.extend_from_slice(&v.to_le_bytes()[..4]);
1321            out.push((comp >> 10) as u8);
1322        }
1323    }
1324    out.extend_from_slice(&streams);
1325    Some(out)
1326}
1327
1328fn encode_literals_huffman(literals: &[u8]) -> Option<Vec<u8>> {
1329    // Count frequencies
1330    let mut counts = [0u32; 256];
1331    let mut max_sym = 0u8;
1332    for &b in literals {
1333        counts[b as usize] += 1;
1334        if b > max_sym {
1335            max_sym = b;
1336        }
1337    }
1338    let n_used = counts.iter().filter(|&&c| c > 0).count();
1339    if n_used < 2 {
1340        return None;
1341    }
1342
1343    // Build length-limited Huffman (max 11 bits)
1344    let (codes, max_bits) = build_huffman_codes(&counts, max_sym as usize)?;
1345
1346    // Encode tree description (weights packed as 4-bit pairs)
1347    let tree_desc = encode_huffman_tree(&codes, max_bits, max_sym as usize);
1348    if tree_desc.is_empty() {
1349        return None;
1350    }
1351
1352    // Encode streams: single stream for < 1KB, 4 streams for >= 1KB
1353    let use_4 = literals.len() >= 1024;
1354    let streams = if use_4 {
1355        encode_huf_4streams(literals, &codes)
1356    } else {
1357        encode_huf_1stream(literals, &codes)
1358    };
1359
1360    let regen = literals.len();
1361    let comp = tree_desc.len() + streams.len();
1362    let lh_size = 3 + (regen >= 1024) as usize + (regen >= 16384) as usize;
1363
1364    let mut out = Vec::with_capacity(lh_size + comp);
1365    let htype = LIT_TYPE_COMPRESSED as u32;
1366
1367    match lh_size {
1368        3 => {
1369            // bit[1:0]=type(2), bit[2]=streams_flag, bit[3]=0, bit[13:4]=regen, bit[23:14]=comp
1370            let sf = if use_4 { 1u32 } else { 0u32 };
1371            let lhc = htype | (sf << 2) | ((regen as u32) << 4) | ((comp as u32) << 14);
1372            out.extend_from_slice(&lhc.to_le_bytes()[..3]);
1373        }
1374        4 => {
1375            let lhc = htype | (2u32 << 2) | ((regen as u32) << 4) | ((comp as u32) << 18);
1376            out.extend_from_slice(&lhc.to_le_bytes()[..4]);
1377        }
1378        _ => {
1379            let lhc = htype | (3u32 << 2) | ((regen as u32) << 4) | ((comp as u32) << 22);
1380            out.extend_from_slice(&lhc.to_le_bytes()[..4]);
1381            out.push((comp >> 10) as u8);
1382        }
1383    }
1384
1385    out.extend_from_slice(&tree_desc);
1386    out.extend_from_slice(&streams);
1387    Some(out)
1388}
1389
1390/// Build Huffman codes — faithful port of C zstd's HUF_buildCTable_wksp.
1391/// 1. Sort symbols descending by count
1392/// 2. Build binary Huffman tree (two-queue merge)
1393/// 3. Enforce max depth via HUF_setMaxHeight
1394/// 4. Generate canonical codes using C zstd's min >>= 1 algorithm
1395fn build_huffman_codes(counts: &[u32; 256], max_sym: usize) -> Option<([(u32, u8); 256], u8)> {
1396    const MAX_BITS: u8 = 11;
1397
1398    // Collect and sort symbols descending by count
1399    let mut syms: Vec<(u32, u8)> = (0..=max_sym)
1400        .filter(|&s| counts[s] > 0)
1401        .map(|s| (counts[s], s as u8))
1402        .collect();
1403    syms.sort_by(|a, b| b.0.cmp(&a.0));
1404    let n = syms.len();
1405    if n < 2 {
1406        return None;
1407    }
1408
1409    // --- Step 1: Build Huffman tree (two-queue merge) ---
1410    let mut node_count = vec![0u64; 2 * n];
1411    let mut node_parent = vec![0u32; 2 * n];
1412    let mut node_nbits = vec![0u8; 2 * n];
1413    for i in 0..n {
1414        node_count[i] = syms[i].0 as u64;
1415    }
1416    for i in n..2 * n {
1417        node_count[i] = u64::MAX / 2;
1418    }
1419
1420    let mut low_s = n as i32 - 1;
1421    let mut low_n = n;
1422    let mut next_node = n;
1423
1424    // Helper: pick smallest from symbol queue or node queue
1425    let pick_smallest =
1426        |node_count: &[u64], low_s: &mut i32, low_n: &mut usize, next_node: usize| -> usize {
1427            if *low_s >= 0
1428                && (*low_n >= next_node || node_count[*low_s as usize] < node_count[*low_n])
1429            {
1430                let r = *low_s as usize;
1431                *low_s -= 1;
1432                r
1433            } else if *low_n < next_node {
1434                let r = *low_n;
1435                *low_n += 1;
1436                r
1437            } else {
1438                usize::MAX // shouldn't happen
1439            }
1440        };
1441
1442    while next_node < 2 * n - 1 {
1443        let n1 = pick_smallest(&node_count, &mut low_s, &mut low_n, next_node);
1444        let n2 = pick_smallest(&node_count, &mut low_s, &mut low_n, next_node);
1445        if n1 == usize::MAX || n2 == usize::MAX {
1446            break;
1447        }
1448        node_count[next_node] = node_count[n1] + node_count[n2];
1449        node_parent[n1] = next_node as u32;
1450        node_parent[n2] = next_node as u32;
1451        next_node += 1;
1452    }
1453    let root = next_node - 1;
1454
1455    // Assign bit lengths top-down
1456    node_nbits[root] = 0;
1457    for i in (n..=root).rev() {
1458        if i < root {
1459            node_nbits[i] = node_nbits[node_parent[i] as usize] + 1;
1460        }
1461    }
1462    for i in 0..n {
1463        node_nbits[i] = node_nbits[node_parent[i] as usize] + 1;
1464    }
1465
1466    // --- Step 2: HUF_setMaxHeight ---
1467    // If tree is too deep, fall back to raw instead of trying to fix it.
1468    // Uses rankLast[] to track the least-frequent symbol at each depth,
1469    // ensuring both Kraft inequality and valid weight-sum.
1470    let largest_bits = *node_nbits[..n].iter().max().unwrap_or(&0);
1471    if largest_bits > MAX_BITS {
1472        let target = MAX_BITS;
1473
1474        // Phase 1: Clamp all > target to target, compute totalCost
1475        let base_cost = 1i32 << (largest_bits - target);
1476        let mut total_cost = 0i32;
1477        // Scan backward (least frequent first, they have the longest codes)
1478        let mut last_non_null = n - 1;
1479        while node_nbits[last_non_null] > target {
1480            total_cost += base_cost - (1i32 << (largest_bits - node_nbits[last_non_null]));
1481            node_nbits[last_non_null] = target;
1482            if last_non_null == 0 {
1483                break;
1484            }
1485            last_non_null -= 1;
1486        }
1487        total_cost >>= largest_bits - target;
1488
1489        // Build rankLast[]: position of last (least frequent) symbol at each rank
1490        // rankLast[k] = index of least-frequent symbol using (target - k) bits
1491        const NO_SYMBOL: u32 = 0xF0F0F0F0;
1492        let mut rank_last = [NO_SYMBOL; 16];
1493        {
1494            let mut current_bits = target;
1495            for pos in (0..=last_non_null).rev() {
1496                if node_nbits[pos] >= current_bits {
1497                    continue;
1498                }
1499                current_bits = node_nbits[pos];
1500                rank_last[(target - current_bits) as usize] = pos as u32;
1501            }
1502        }
1503
1504        // Phase 2: Repay cost by lengthening symbols (increasing their nbBits)
1505        while total_cost > 0 {
1506            // Find the best rank to decrease: target the next power-of-2 chunk
1507            let mut n_bits_to_decrease = 32 - (total_cost as u32).leading_zeros();
1508            // but don't exceed available ranks
1509            if n_bits_to_decrease > largest_bits as u32 - target as u32 + 1 {
1510                n_bits_to_decrease = largest_bits as u32 - target as u32 + 1;
1511            }
1512
1513            // Try to find best rank: prefer promoting cheap symbols
1514            while n_bits_to_decrease > 1 {
1515                let high_pos = rank_last[n_bits_to_decrease as usize];
1516                let low_pos = rank_last[n_bits_to_decrease as usize - 1];
1517                if high_pos == NO_SYMBOL {
1518                    n_bits_to_decrease -= 1;
1519                    continue;
1520                }
1521                if low_pos == NO_SYMBOL {
1522                    break;
1523                }
1524                let high_total = syms[high_pos as usize].0;
1525                let low_total = 2 * syms[low_pos as usize].0;
1526                if high_total <= low_total {
1527                    break;
1528                }
1529                n_bits_to_decrease -= 1;
1530            }
1531
1532            // Find a non-empty rank if current is empty
1533            while n_bits_to_decrease as usize <= 14
1534                && rank_last[n_bits_to_decrease as usize] == NO_SYMBOL
1535            {
1536                n_bits_to_decrease += 1;
1537            }
1538            if n_bits_to_decrease as usize > 14
1539                || rank_last[n_bits_to_decrease as usize] == NO_SYMBOL
1540            {
1541                break; // can't repay
1542            }
1543
1544            // Promote the symbol: increase its nbBits by 1 (C allows overshoot here)
1545            total_cost -= 1i32 << (n_bits_to_decrease - 1);
1546            let pos = rank_last[n_bits_to_decrease as usize] as usize;
1547            node_nbits[pos] += 1;
1548
1549            // Update rankLast for the new rank
1550            if rank_last[n_bits_to_decrease as usize - 1] == NO_SYMBOL {
1551                rank_last[n_bits_to_decrease as usize - 1] = rank_last[n_bits_to_decrease as usize];
1552            }
1553
1554            // Update rankLast for the old rank
1555            if rank_last[n_bits_to_decrease as usize] == 0 {
1556                rank_last[n_bits_to_decrease as usize] = NO_SYMBOL;
1557            } else {
1558                let prev = rank_last[n_bits_to_decrease as usize] - 1;
1559                rank_last[n_bits_to_decrease as usize] = prev;
1560                if node_nbits[prev as usize] != target - n_bits_to_decrease as u8 {
1561                    rank_last[n_bits_to_decrease as usize] = NO_SYMBOL;
1562                }
1563            }
1564        }
1565
1566        // Phase 3: Overshoot correction (totalCost < 0)
1567        // Port of C zstd: demote rank-0 symbols to rank-1 (decrease nbBits by 1)
1568        while total_cost < 0 {
1569            if rank_last[1] == NO_SYMBOL {
1570                // No rank-1 symbols. Find last rank-0 symbol and demote it.
1571                let mut p = last_non_null;
1572                while p > 0 && node_nbits[p] == target {
1573                    p -= 1;
1574                }
1575                // p+1 is a rank-0 symbol (using target bits)
1576                if p + 1 < n && node_nbits[p + 1] == target {
1577                    node_nbits[p + 1] -= 1; // demote: target → target-1
1578                    rank_last[1] = (p + 1) as u32;
1579                    total_cost += 1;
1580                } else {
1581                    break; // can't correct
1582                }
1583            } else {
1584                // Demote the symbol just after rankLast[1] boundary
1585                let next = rank_last[1] as usize + 1;
1586                if next < n && node_nbits[next] == target {
1587                    node_nbits[next] -= 1;
1588                    rank_last[1] += 1;
1589                    total_cost += 1;
1590                } else {
1591                    // rankLast[1]+1 is not a rank-0 symbol, need to find one
1592                    rank_last[1] = NO_SYMBOL;
1593                    // Will retry with the NO_SYMBOL path above
1594                }
1595            }
1596        }
1597
1598        // If still not zero, fall back
1599        if total_cost != 0 {
1600            for i in 0..n {
1601                node_nbits[i] = 0;
1602            } // will fail Kraft
1603        }
1604    }
1605
1606    // --- Step 3: Extract code lengths and validate ---
1607    let mut lengths = [0u8; 256];
1608    for i in 0..n {
1609        lengths[syms[i].1 as usize] = node_nbits[i];
1610    }
1611
1612    let max_bits = *lengths.iter().max().unwrap_or(&0);
1613    if max_bits == 0 {
1614        return None;
1615    }
1616
1617    // Verify Kraft inequality: sum of 2^(max-len) must equal 2^max
1618    let kraft: u64 = (0..=max_sym)
1619        .filter(|&s| lengths[s] > 0)
1620        .map(|s| 1u64 << (max_bits - lengths[s]))
1621        .sum();
1622    if kraft != (1u64 << max_bits) {
1623        return None;
1624    }
1625
1626    // Weight-sum is automatically valid when Kraft is valid:
1627    // encode_huffman_tree pops the last non-zero weight, and the remaining
1628    // weight_sum = kraft_sum - 2^(last_weight-1) = 2^max - 2^k, which
1629    // leaves a valid power-of-2 leftover for the implicit last weight.
1630
1631    // Count symbols per rank
1632    let mut nb_per_rank = [0u32; 16];
1633    for &l in &lengths {
1634        if l > 0 {
1635            nb_per_rank[l as usize] += 1;
1636        }
1637    }
1638
1639    // zstd-style canonical code generation — exact mirror of decoder's rank_indexes.
1640    // Decoder: rank_indexes[max_bits] = 0
1641    //          rank_indexes[bits-1] = rank_indexes[bits] + bit_ranks[bits] * (1 << (max_bits - bits))
1642    // Code for a symbol at rank `bits` = rank_indexes[bits] / (1 << (max_bits - bits))
1643    let mut rank_indexes = [0u32; 16];
1644    rank_indexes[max_bits as usize] = 0;
1645    for bits in (1..=max_bits as usize).rev() {
1646        rank_indexes[bits - 1] =
1647            rank_indexes[bits] + nb_per_rank[bits] * (1u32 << (max_bits as usize - bits));
1648    }
1649
1650    // Assign codes: within each bit length, symbols get consecutive codes
1651    let mut next_code = [0u32; 16];
1652    for bits in 1..=max_bits as usize {
1653        next_code[bits] = rank_indexes[bits] >> (max_bits as usize - bits);
1654    }
1655
1656    let mut codes = [(0u32, 0u8); 256];
1657    for s in 0..=max_sym {
1658        if lengths[s] > 0 {
1659            codes[s] = (next_code[lengths[s] as usize], lengths[s]);
1660            next_code[lengths[s] as usize] += 1;
1661        }
1662    }
1663
1664    Some((codes, max_bits))
1665}
1666
1667fn encode_huffman_tree(codes: &[(u32, u8); 256], max_bits: u8, max_sym: usize) -> Vec<u8> {
1668    if max_bits == 0 {
1669        return vec![];
1670    }
1671    let mut weights: Vec<u8> = (0..=max_sym)
1672        .map(|s| {
1673            if codes[s].1 > 0 {
1674                max_bits + 1 - codes[s].1
1675            } else {
1676                0
1677            }
1678        })
1679        .collect();
1680    while weights.last() == Some(&0) && weights.len() > 1 {
1681        weights.pop();
1682    }
1683    if !weights.is_empty() {
1684        weights.pop();
1685    } // last weight is implicit
1686    if weights.is_empty() || weights.len() > 255 {
1687        return vec![];
1688    }
1689
1690    // Check all weights fit in 4 bits
1691    if weights.iter().any(|&w| w > 12) {
1692        return vec![];
1693    }
1694
1695    let num = weights.len();
1696
1697    if num <= 128 {
1698        // Direct mode: header = num + 127, packed 4-bit pairs
1699        let mut desc = Vec::with_capacity(1 + num.div_ceil(2));
1700        desc.push((num as u8) + 127);
1701        for pair in weights.chunks(2) {
1702            let w0 = pair[0];
1703            let w1 = if pair.len() > 1 { pair[1] } else { 0 };
1704            desc.push((w0 << 4) | (w1 & 0x0F));
1705        }
1706        desc
1707    } else {
1708        // >128 weights: FSE-compressed 2-stream interleaved encoding
1709        let fse_result = encode_weights_fse(&weights);
1710        match fse_result {
1711            Some(compressed) if compressed.len() < 127 => {
1712                let header_byte = compressed.len() as u8;
1713                // Verify roundtrip — only use if weights match exactly
1714                let verify = crate::decode::decode_huf_weights_from_fse(&compressed, header_byte);
1715                if let Ok(ref dw) = verify {
1716                    if *dw == weights {
1717                        let mut desc = Vec::with_capacity(1 + compressed.len());
1718                        desc.push(header_byte);
1719                        desc.extend_from_slice(&compressed);
1720                        return desc;
1721                    }
1722                }
1723                // FSE roundtrip failed — fall through to direct mode fallback
1724                if num <= 128 {
1725                    let mut desc = Vec::with_capacity(1 + num.div_ceil(2));
1726                    desc.push((num as u8) + 127);
1727                    for pair in weights.chunks(2) {
1728                        let w0 = pair[0];
1729                        let w1 = if pair.len() > 1 { pair[1] } else { 0 };
1730                        desc.push((w0 << 4) | (w1 & 0x0F));
1731                    }
1732                    desc
1733                } else {
1734                    vec![]
1735                }
1736            }
1737            _ => {
1738                // FSE too large or failed — try direct if num <= 128
1739                if num <= 128 {
1740                    let mut desc = Vec::with_capacity(1 + num.div_ceil(2));
1741                    desc.push((num as u8) + 127);
1742                    for pair in weights.chunks(2) {
1743                        let w0 = pair[0];
1744                        let w1 = if pair.len() > 1 { pair[1] } else { 0 };
1745                        desc.push((w0 << 4) | (w1 & 0x0F));
1746                    }
1747                    desc
1748                } else {
1749                    vec![] // can't encode 129+ weights without FSE
1750                }
1751            }
1752        }
1753    }
1754}
1755
1756/// Calculate baseline and num_bits for FSE decode table entry.
1757/// EXACT copy of decode.rs fse_calc_baseline_and_numbits.
1758/// FSE-compress weights using 2-stream interleaved encoding.
1759/// Uses decoder's FSE table directly for encoding to guarantee compatibility.
1760fn encode_weights_fse(weights: &[u8]) -> Option<Vec<u8>> {
1761    let mut counts = [0u32; 13];
1762    let mut max_w = 0u8;
1763    for &w in weights {
1764        counts[w as usize] += 1;
1765        if w > max_w {
1766            max_w = w;
1767        }
1768    }
1769    if max_w == 0 {
1770        return None;
1771    }
1772
1773    let table_log = 6u32;
1774    let table_size = 1u32 << table_log;
1775    let total = weights.len() as u32;
1776
1777    // Normalize
1778    let mut norm = [0i16; 13];
1779    let mut dist = 0u32;
1780    for s in 0..=max_w as usize {
1781        if counts[s] == 0 {
1782            continue;
1783        }
1784        norm[s] = std::cmp::max(
1785            1,
1786            (counts[s] as u64 * table_size as u64 / total as u64) as i16,
1787        );
1788        dist += norm[s] as u32;
1789    }
1790    while dist > table_size {
1791        for s in 0..=max_w as usize {
1792            if norm[s] > 1 {
1793                norm[s] -= 1;
1794                dist -= 1;
1795                break;
1796            }
1797        }
1798    }
1799    while dist < table_size {
1800        let best = (0..=max_w as usize).max_by_key(|&s| counts[s]).unwrap_or(0);
1801        norm[best] += 1;
1802        dist += 1;
1803    }
1804
1805    let fse = super::fse::FseCTable::build(&norm, max_w as usize, table_log);
1806
1807    // --- FSE table header (must match decode.rs read_probabilities exactly) ---
1808    // Decoder: accuracy_log = 5 + get_bits(4)
1809    //          loop: max_remaining = prob_sum - counter + 1
1810    //                bits_to_read = highest_bit_set(max_remaining)
1811    //                read bits_to_read, apply low_threshold logic
1812    //                prob = value - 1
1813    let mut hdr = Vec::with_capacity(16);
1814    let mut bb: u64 = (table_log - 5) as u64;
1815    let mut bp = 4u32;
1816    let prob_sum = table_size;
1817    let mut counter = 0u32;
1818
1819    let mut s = 0usize;
1820    while s <= max_w as usize && counter < prob_sum {
1821        let prob = norm[s] as i32;
1822        let value = (prob + 1) as u32;
1823
1824        let max_remaining = prob_sum - counter + 1;
1825        let bits_to_read = 32 - max_remaining.leading_zeros();
1826        let low_threshold = ((1u32 << bits_to_read) - 1) - max_remaining;
1827        let mask = (1u32 << (bits_to_read - 1)) - 1;
1828
1829        if value < low_threshold {
1830            bb |= (value as u64) << bp;
1831            bp += bits_to_read - 1;
1832        } else if value <= mask {
1833            bb |= (value as u64) << bp;
1834            bp += bits_to_read;
1835        } else {
1836            bb |= ((value + low_threshold) as u64) << bp;
1837            bp += bits_to_read;
1838        }
1839        while bp >= 8 {
1840            hdr.push(bb as u8);
1841            bb >>= 8;
1842            bp -= 8;
1843        }
1844
1845        if prob > 0 {
1846            counter += prob as u32;
1847        } else if prob == -1 {
1848            counter += 1;
1849        }
1850
1851        if prob == 0 {
1852            let mut repeat = 0u32;
1853            while s + 1 + repeat as usize <= max_w as usize
1854                && norm[s + 1 + repeat as usize] == 0
1855                && repeat < 3
1856            {
1857                repeat += 1;
1858            }
1859            bb |= (repeat as u64) << bp;
1860            bp += 2;
1861            while bp >= 8 {
1862                hdr.push(bb as u8);
1863                bb >>= 8;
1864                bp -= 8;
1865            }
1866            s += repeat as usize;
1867            while repeat == 3 {
1868                repeat = 0;
1869                while s + 1 + repeat as usize <= max_w as usize
1870                    && norm[s + 1 + repeat as usize] == 0
1871                    && repeat < 3
1872                {
1873                    repeat += 1;
1874                }
1875                bb |= (repeat as u64) << bp;
1876                bp += 2;
1877                while bp >= 8 {
1878                    hdr.push(bb as u8);
1879                    bb >>= 8;
1880                    bp -= 8;
1881                }
1882                s += repeat as usize;
1883            }
1884        }
1885        s += 1;
1886    }
1887    if bp > 0 {
1888        hdr.push(bb as u8);
1889    }
1890
1891    // --- Build decoder-compatible FSE table and encode using it ---
1892    // This guarantees encode/decode compatibility by using the same table structure.
1893    let ts = table_size as usize;
1894
1895    // Build decode table (same algorithm as decode.rs)
1896    let mut dec_symbol = vec![0u8; ts];
1897    let mut dec_baseline = vec![0u32; ts];
1898    let mut dec_numbits = vec![0u8; ts];
1899
1900    // Place -1 symbols at the end
1901    let mut neg_idx = ts;
1902    for s in 0..=max_w as usize {
1903        if norm[s] == -1 {
1904            neg_idx -= 1;
1905            dec_symbol[neg_idx] = s as u8;
1906            dec_baseline[neg_idx] = 0;
1907            dec_numbits[neg_idx] = table_log as u8;
1908        }
1909    }
1910
1911    // Spread remaining symbols
1912    let mut pos = 0usize;
1913    for s in 0..=max_w as usize {
1914        if norm[s] <= 0 {
1915            continue;
1916        }
1917        for _ in 0..norm[s] {
1918            dec_symbol[pos] = s as u8;
1919            pos += (ts >> 1) + (ts >> 3) + 3;
1920            pos &= ts - 1;
1921            while pos >= neg_idx {
1922                pos += (ts >> 1) + (ts >> 3) + 3;
1923                pos &= ts - 1;
1924            }
1925        }
1926    }
1927
1928    // CTable-based 2-stream interleaved encoding.
1929    // FSE backward encoding: init with LAST symbol, encode second-to-last down to first.
1930    // After all encodes, the final state carries the FIRST symbol (decoder's first output).
1931    let stream1: Vec<u8> = weights.iter().step_by(2).copied().collect();
1932    let stream2: Vec<u8> = weights.iter().skip(1).step_by(2).copied().collect();
1933    let len1 = stream1.len();
1934    let len2 = stream2.len();
1935
1936    // Init with LAST symbol of each stream (= first encoded, last decoded)
1937    let mut st1 = fse.init_state(*stream1.last().unwrap() as usize);
1938    let mut st2 = if len2 > 0 {
1939        fse.init_state(*stream2.last().unwrap() as usize)
1940    } else {
1941        0
1942    };
1943
1944    let mut bw = super::bitstream::BackwardBitWriter::new();
1945
1946    // Encode from second-to-last down to first (index 0).
1947    // After all encodes, state carries stream[0]'s symbol.
1948    // Decoder reads: init(state) → peek(stream[0]) → update → peek(stream[1]) → ...
1949    let _max_idx = std::cmp::max(len1, len2);
1950    // Encode each stream from second-to-last down to first.
1951    // init handles the last element, so encode [0..last-1].
1952    // Interleave order: decoder reads st1 update first, so write st2 first (backward).
1953    let max_encode = std::cmp::max(len1.saturating_sub(1), len2.saturating_sub(1));
1954    for i in (0..max_encode).rev() {
1955        if i < len2.saturating_sub(1) {
1956            let (bits, nb, ns) = fse.encode_symbol(st2, stream2[i] as usize);
1957            bw.add_bits(bits as u64, nb);
1958            bw.flush_bits();
1959            st2 = ns;
1960        }
1961        if i < len1.saturating_sub(1) {
1962            let (bits, nb, ns) = fse.encode_symbol(st1, stream1[i] as usize);
1963            bw.add_bits(bits as u64, nb);
1964            bw.flush_bits();
1965            st1 = ns;
1966        }
1967    }
1968
1969    // Write init states — convert CTable state to decoder index
1970    bw.add_bits((st2 - table_size) as u64, table_log);
1971    bw.flush_bits();
1972    bw.add_bits((st1 - table_size) as u64, table_log);
1973    bw.flush_bits();
1974
1975    let bitstream = bw.finish();
1976    let mut out = hdr;
1977    out.extend_from_slice(&bitstream);
1978    Some(out)
1979}
1980
1981/// Encode one Huffman stream (symbols in reverse, padded with sentinel bit).
1982fn encode_huf_1stream(data: &[u8], codes: &[(u32, u8); 256]) -> Vec<u8> {
1983    let mut bw = super::bitstream::BackwardBitWriter::new();
1984    // Encode symbols in reverse (backward bitstream convention)
1985    for &sym in data.iter().rev() {
1986        let (code, nb) = codes[sym as usize];
1987        if nb == 0 {
1988            continue;
1989        }
1990        bw.add_bits(code as u64, nb as u32);
1991        bw.flush_bits();
1992    }
1993    bw.finish() // adds sentinel, no reverse needed
1994}
1995
1996fn encode_huf_4streams(data: &[u8], codes: &[(u32, u8); 256]) -> Vec<u8> {
1997    let q = data.len().div_ceil(4);
1998    let ends = [
1999        q,
2000        std::cmp::min(q * 2, data.len()),
2001        std::cmp::min(q * 3, data.len()),
2002        data.len(),
2003    ];
2004    let starts = [0, q, ends[1], ends[2]];
2005
2006    let c: Vec<Vec<u8>> = (0..4)
2007        .map(|i| encode_huf_1stream(&data[starts[i]..ends[i]], codes))
2008        .collect();
2009
2010    let mut out = Vec::with_capacity(6 + c.iter().map(|v| v.len()).sum::<usize>());
2011    // Jump table: sizes of first 3 streams (u16 LE each)
2012    for i in 0..3 {
2013        out.extend_from_slice(&(c[i].len() as u16).to_le_bytes());
2014    }
2015    for stream in &c {
2016        out.extend_from_slice(stream);
2017    }
2018    out
2019}
2020
2021// =========================================================================
2022// Literals section encoding (Raw mode)
2023// =========================================================================
2024
2025fn encode_literals_rle(out: &mut Vec<u8>, byte: u8, size: usize) {
2026    if size <= 31 {
2027        out.push(LIT_TYPE_RLE | ((size as u8) << 3));
2028    } else if size <= 4095 {
2029        let h = (LIT_TYPE_RLE as u16) | (1 << 2) | ((size as u16) << 4);
2030        out.extend_from_slice(&h.to_le_bytes());
2031    } else {
2032        let h = (LIT_TYPE_RLE as u32) | (3 << 2) | ((size as u32) << 4);
2033        out.extend_from_slice(&h.to_le_bytes()[..3]);
2034    }
2035    out.push(byte);
2036}
2037
2038fn encode_literals_raw(out: &mut Vec<u8>, literals: &[u8]) {
2039    let size = literals.len();
2040
2041    if size <= 31 {
2042        // 1-byte header: type=0 (raw), size in 5 bits
2043        out.push(LIT_TYPE_RAW | ((size as u8) << 3));
2044    } else if size <= 4095 {
2045        // 2-byte header
2046        let h = (LIT_TYPE_RAW as u16) | (1 << 2) | ((size as u16) << 4);
2047        out.extend_from_slice(&h.to_le_bytes());
2048    } else {
2049        // 3-byte header
2050        let h = (LIT_TYPE_RAW as u32) | (3 << 2) | ((size as u32) << 4);
2051        out.extend_from_slice(&h.to_le_bytes()[..3]);
2052    }
2053
2054    out.extend_from_slice(literals);
2055}
2056
2057// =========================================================================
2058// Sequences section encoding using exact C-compatible FSE tables
2059// =========================================================================
2060
2061/// Encode sequences with cross-block Repeat mode support.
2062/// If the current block's symbol distribution matches the previous block, use Repeat mode
2063/// (no table header needed). Otherwise choose best of Predefined/RLE/Custom FSE.
2064fn encode_sequences_section(out: &mut Vec<u8>, sequences: &[EncodedSequence]) {
2065    let nb_seq = sequences.len();
2066
2067    // Number of sequences header
2068    if nb_seq < 128 {
2069        out.push(nb_seq as u8);
2070    } else if nb_seq < 0x7F00 {
2071        out.push(((nb_seq >> 8) as u8) + 128);
2072        out.push(nb_seq as u8);
2073    } else {
2074        out.push(255);
2075        out.extend_from_slice(&((nb_seq - 0x7F00) as u16).to_le_bytes());
2076    }
2077
2078    if nb_seq == 0 {
2079        return;
2080    }
2081
2082    // Convert sequences to codes + extra bit values
2083    let mut ll_codes_v = Vec::with_capacity(nb_seq);
2084    let mut ml_codes_v = Vec::with_capacity(nb_seq);
2085    let mut off_codes_v = Vec::with_capacity(nb_seq);
2086    let mut ll_values = Vec::with_capacity(nb_seq);
2087    let mut ml_values = Vec::with_capacity(nb_seq);
2088    let mut off_values = Vec::with_capacity(nb_seq);
2089
2090    for seq in sequences {
2091        let llc = ll_code(seq.ll);
2092        let ml_base = seq.ml - ZSTD_MINMATCH as u32;
2093        let mlc = ml_code(ml_base);
2094        let ofc = off_code(seq.of_value);
2095
2096        ll_codes_v.push(llc);
2097        ml_codes_v.push(mlc);
2098        off_codes_v.push(ofc);
2099        ll_values.push(seq.ll - LL_BASE[llc as usize]);
2100        ml_values.push(seq.ml - ML_BASE[mlc as usize]);
2101        off_values.push(if ofc > 0 {
2102            seq.of_value - (1u32 << ofc)
2103        } else {
2104            0
2105        });
2106    }
2107
2108    // Choose best mode for each table: Predefined vs RLE vs Custom FSE
2109    let ll_mode = choose_seq_mode(
2110        &ll_codes_v,
2111        MAX_LL,
2112        LL_DEFAULT_NORM_LOG,
2113        &LL_DEFAULT_NORM,
2114        LL_FSE_LOG,
2115    );
2116    let of_mode = choose_seq_mode(
2117        &off_codes_v,
2118        OF_DEFAULT_NORM.len() - 1,
2119        OF_DEFAULT_NORM_LOG,
2120        &OF_DEFAULT_NORM,
2121        OFF_FSE_LOG,
2122    );
2123    let ml_mode = choose_seq_mode(
2124        &ml_codes_v,
2125        MAX_ML,
2126        ML_DEFAULT_NORM_LOG,
2127        &ML_DEFAULT_NORM,
2128        ML_FSE_LOG,
2129    );
2130
2131    // Write compression modes byte
2132    let mode_byte = (ll_mode.tag() << 6) | (of_mode.tag() << 4) | (ml_mode.tag() << 2);
2133    out.push(mode_byte);
2134
2135    // Write table descriptions for non-predefined modes, then build tables
2136    let ll_table =
2137        write_seq_table_and_build(out, &ll_mode, &LL_DEFAULT_NORM, MAX_LL, LL_DEFAULT_NORM_LOG);
2138    let of_table = write_seq_table_and_build(
2139        out,
2140        &of_mode,
2141        &OF_DEFAULT_NORM,
2142        OF_DEFAULT_NORM.len() - 1,
2143        OF_DEFAULT_NORM_LOG,
2144    );
2145    let ml_table =
2146        write_seq_table_and_build(out, &ml_mode, &ML_DEFAULT_NORM, MAX_ML, ML_DEFAULT_NORM_LOG);
2147
2148    // Encode with FSE sequence encoder
2149    let bitstream = super::fse::encode_sequences(
2150        &ll_table,
2151        &of_table,
2152        &ml_table,
2153        &ll_codes_v,
2154        &off_codes_v,
2155        &ml_codes_v,
2156        &ll_values,
2157        &ml_values,
2158        &off_values,
2159    );
2160    out.extend_from_slice(&bitstream);
2161}
2162
2163// =========================================================================
2164// Custom FSE table mode selection for sequences
2165// =========================================================================
2166
2167/// Chosen compression mode for a sequence table.
2168enum SeqTableMode {
2169    Predefined,
2170    Rle(u8),
2171    Fse {
2172        norm: Vec<i16>,
2173        max_symbol: usize,
2174        table_log: u32,
2175        header_bytes: Vec<u8>,
2176    },
2177}
2178
2179impl SeqTableMode {
2180    fn tag(&self) -> u8 {
2181        match self {
2182            SeqTableMode::Predefined => SEQ_MODE_PREDEFINED,
2183            SeqTableMode::Rle(_) => SEQ_MODE_RLE,
2184            SeqTableMode::Fse { .. } => SEQ_MODE_FSE,
2185        }
2186    }
2187}
2188
2189/// Normalize symbol counts to probability distribution for FSE table.
2190/// Port of C zstd's FSE_normalizeCount() with 62-bit precision scaling.
2191fn normalize_counts(counts: &[u32], max_symbol: usize, table_log: u32) -> Vec<i16> {
2192    let table_size = 1u32 << table_log;
2193    let total: u64 = counts[..=max_symbol].iter().map(|&c| c as u64).sum();
2194    if total == 0 {
2195        return vec![0i16; max_symbol + 1];
2196    }
2197
2198    let mut norm = vec![0i16; max_symbol + 1];
2199
2200    // Use C zstd's high-precision scaling: step = (1<<62) / total
2201    let scale: u32 = 62 - table_log;
2202    let step: u64 = (1u64 << 62) / total;
2203    let v_step: u64 = 1u64 << (scale - 20);
2204    let low_threshold: u64 = total >> table_log;
2205
2206    // C zstd's rtbTable for precise rounding of small probabilities
2207    static RTB_TABLE: [u32; 8] = [0, 473195, 504333, 520860, 550000, 700000, 750000, 830000];
2208
2209    // Use lowProbCount = -1 for large blocks (>= 2048 sequences), 1 otherwise
2210    let use_low_prob_count = total >= 2048;
2211    let low_prob_count: i16 = if use_low_prob_count { -1 } else { 1 };
2212
2213    let mut still_to_distribute = table_size as i32;
2214    let mut largest_sym = 0usize;
2215    let mut largest_prob = 0i16;
2216
2217    for s in 0..=max_symbol {
2218        if counts[s] as u64 == total {
2219            // Single-symbol dominance
2220            norm[s] = table_size as i16;
2221            return norm;
2222        }
2223        if counts[s] == 0 {
2224            continue;
2225        }
2226
2227        if (counts[s] as u64) <= low_threshold {
2228            norm[s] = low_prob_count;
2229            still_to_distribute -= 1;
2230        } else {
2231            let mut proba = ((counts[s] as u64 * step) >> scale) as i16;
2232            if proba < 8 {
2233                // Use rtbTable for precise rounding
2234                let rest_to_beat = v_step as u128 * RTB_TABLE[proba as usize] as u128;
2235                let actual = (counts[s] as u128 * step as u128) - ((proba as u128) << scale);
2236                if actual > rest_to_beat {
2237                    proba += 1;
2238                }
2239            }
2240            if proba > (table_size >> 1) as i16 {
2241                proba = (table_size >> 1) as i16; // cap at half table
2242            }
2243            norm[s] = std::cmp::max(1, proba);
2244            still_to_distribute -= norm[s] as i32;
2245        }
2246
2247        if norm[s] > largest_prob {
2248            largest_prob = norm[s];
2249            largest_sym = s;
2250        }
2251    }
2252
2253    // Adjust largest symbol to distribute remaining
2254    if -still_to_distribute >= (norm[largest_sym] >> 1) as i32 {
2255        // Pathological case: use proportional redistribution
2256        normalize_counts_m2(&mut norm, counts, max_symbol, table_log, total);
2257    } else {
2258        norm[largest_sym] += still_to_distribute as i16;
2259    }
2260
2261    norm
2262}
2263
/// Fallback normalization for distributions the primary method overshoots
/// (port of C zstd's `FSE_normalizeM2`).
///
/// Overwrites `norm[..=max_symbol]` so that the entries add up to exactly
/// `1 << table_log` slots (a `-1` entry counts as one slot):
/// 1. symbols with count `<= total >> table_log` get `-1`, symbols with count
///    `<= (3 * total) >> (table_log + 1)` get `1`;
/// 2. if the remaining symbols would average more than that second threshold per
///    slot, the threshold is recomputed from what is left and more symbols are
///    set to `1`;
/// 3. if no symbol remains, the leftover slots go to the most frequent symbol
///    (when every symbol is present) or are dealt out one by one to the symbols
///    that already hold a positive entry;
/// 4. otherwise the remaining symbols share the leftover slots in proportion to
///    their counts, using a cumulative fixed-point sum that starts half a step
///    in so the shares add up exactly.
///
/// `total` is expected to be the sum of `counts[..=max_symbol]`, and `table_log`
/// a small accuracy log (well below 32). If the low-probability symbols alone
/// already need the whole table (or more), the function returns early and any
/// symbol not yet assigned is left at 0.
fn normalize_counts_m2(
    norm: &mut [i16],
    counts: &[u32],
    max_symbol: usize,
    table_log: u32,
    total: u64,
) {
    let table_size = 1u64 << table_log;
    let low_threshold = total >> table_log;
    let mut low_one = (total * 3) >> (table_log + 1);

    // Pass 1: settle the rare symbols, remember everything else as pending.
    let mut distributed = 0u64;
    let mut pending: Vec<usize> = Vec::new();
    for s in 0..=max_symbol {
        let count = u64::from(counts[s]);
        norm[s] = if count == 0 {
            0
        } else if count <= low_threshold {
            distributed += 1;
            -1
        } else if count <= low_one {
            distributed += 1;
            1
        } else {
            pending.push(s);
            0
        };
    }
    if distributed >= table_size {
        return;
    }
    let mut to_distribute = table_size - distributed;
    let mut pending_total: u64 = pending.iter().map(|&s| u64::from(counts[s])).sum();

    // Pass 2: if a proportional share could round to zero, raise the threshold
    // and hand out more single slots.
    if pending_total / to_distribute > low_one {
        low_one = (pending_total * 3) / (to_distribute * 2);
        pending.retain(|&s| {
            if u64::from(counts[s]) <= low_one {
                norm[s] = 1;
                distributed += 1;
                false
            } else {
                true
            }
        });
        if distributed >= table_size {
            return;
        }
        to_distribute = table_size - distributed;
        pending_total = pending.iter().map(|&s| u64::from(counts[s])).sum();
    }

    if pending.is_empty() {
        // Every present symbol already holds one slot; place the leftovers.
        let boostable: Vec<usize> = (0..=max_symbol).filter(|&s| norm[s] > 0).collect();
        let all_present = distributed == max_symbol as u64 + 1;
        if all_present || boostable.is_empty() {
            // Give all remaining slots to the most frequent symbol (first on ties).
            let mut best = 0usize;
            for s in 1..=max_symbol {
                if counts[s] > counts[best] {
                    best = s;
                }
            }
            if counts[best] > 0 {
                // `abs` turns a `-1` marker into the one slot it already occupies.
                norm[best] = norm[best].abs() + to_distribute as i16;
            }
        } else {
            // Deal the remaining slots round-robin to the positive entries.
            for &s in boostable.iter().cycle().take(to_distribute as usize) {
                norm[s] += 1;
            }
        }
        return;
    }

    // Proportional split of the remaining slots among the pending symbols.
    let v_step_log = 62 - table_log;
    let mid = (1u128 << (v_step_log - 1)) - 1;
    let r_step = (((1u128 << v_step_log) * u128::from(to_distribute)) + mid)
        / u128::from(pending_total);

    // Starting at `mid` makes the cumulative boundaries land on exactly
    // `to_distribute` whole slots in total.
    let mut tmp_total = mid;
    for &s in &pending {
        let end = tmp_total + u128::from(counts[s]) * r_step;
        let weight = (end >> v_step_log) - (tmp_total >> v_step_log);
        norm[s] = std::cmp::max(1, weight as i16);
        tmp_total = end;
    }
}
2318
/// Serialize a normalized FSE distribution into its variable-bit table header.
///
/// Bits are packed least-significant first. The header starts with the 4-bit
/// value `table_log - 5`, then one variable-width field per symbol holding
/// `probability + 1`; the field width shrinks as table slots are used up. A
/// probability of 0 is followed by 2-bit repeat flags giving the number of
/// further zero-probability symbols (a flag of 3 means another flag follows).
/// Encoding stops after `max_symbol` or once the probabilities written cover the
/// whole table, and a final partial byte is flushed.
fn encode_fse_header(norm: &[i16], max_symbol: usize, table_log: u32) -> Vec<u8> {
    /// Accumulates bits LSB-first and spills whole bytes as they fill up.
    struct BitSink {
        acc: u64,
        filled: u32,
        bytes: Vec<u8>,
    }

    impl BitSink {
        fn put(&mut self, value: u32, width: u32) {
            self.acc |= u64::from(value) << self.filled;
            self.filled += width;
            while self.filled >= 8 {
                self.bytes.push(self.acc as u8);
                self.acc >>= 8;
                self.filled -= 8;
            }
        }
    }

    // Number of zero-probability symbols (at most 3) directly after `at`.
    let zero_run_after = |at: usize| -> u32 {
        let mut run = 0u32;
        while at + 1 + run as usize <= max_symbol
            && norm[at + 1 + run as usize] == 0
            && run < 3
        {
            run += 1;
        }
        run
    };

    let slots = 1u32 << table_log;
    // The accuracy log occupies the first four bits: accuracy_log = 5 + low4bits.
    let mut sink = BitSink {
        acc: (table_log - 5) as u64,
        filled: 4,
        bytes: Vec::with_capacity(32),
    };
    let mut used = 0u32;
    let mut sym = 0usize;

    while sym <= max_symbol && used < slots {
        let prob = i32::from(norm[sym]);
        let coded = (prob + 1) as u32;

        // Field width follows the number of values still possible.
        let span = slots - used + 1;
        let width = 32 - span.leading_zeros();
        let short_limit = ((1u32 << width) - 1) - span;
        let half_mask = (1u32 << (width - 1)) - 1;

        if coded < short_limit {
            // Small values save one bit.
            sink.put(coded, width - 1);
        } else if coded <= half_mask {
            sink.put(coded, width);
        } else {
            sink.put(coded + short_limit, width);
        }

        match prob {
            -1 => used += 1,
            0 => {
                // Repeat flags: keep emitting while a flag reports a full run of 3.
                loop {
                    let run = zero_run_after(sym);
                    sink.put(run, 2);
                    sym += run as usize;
                    if run != 3 {
                        break;
                    }
                }
            }
            p if p > 0 => used += p as u32,
            _ => {}
        }

        sym += 1;
    }

    if sink.filled > 0 {
        sink.bytes.push(sink.acc as u8);
    }
    sink.bytes
}
2410
/// Approximate number of bits needed to encode symbols with frequencies `counts`
/// using the normalized distribution `norm` at accuracy `table_log`.
///
/// Each occurrence of a symbol is charged `table_log - floor(log2(prob))` bits,
/// where a `-1` entry counts as probability 1; `table_log` bits are added once
/// for the initial state. Returns `u64::MAX` when a symbol that occurs has no
/// entry in `norm` or a zero probability.
fn cross_entropy_cost(norm: &[i16], table_log: u32, counts: &[u32; 256], max_sym: usize) -> u64 {
    let log = u64::from(table_log);
    let mut bits = 0u64;

    for sym in 0..=max_sym {
        let occurrences = u64::from(counts[sym]);
        if occurrences == 0 {
            continue;
        }
        let prob = match norm.get(sym).copied() {
            None | Some(0) => return u64::MAX,
            Some(-1) => 1u64,
            Some(p) => p as u64,
        };
        bits += occurrences * (log - u64::from(prob.ilog2()));
    }

    bits + log // state initialisation
}
2430
2431/// Choose best mode considering Repeat from previous block.
2432fn choose_seq_mode(
2433    codes: &[u8],
2434    max_symbol_default: usize,
2435    default_log: u32,
2436    default_norm: &[i16],
2437    max_log: u32,
2438) -> SeqTableMode {
2439    if codes.is_empty() {
2440        return SeqTableMode::Predefined;
2441    }
2442
2443    // Count symbol frequencies
2444    let mut counts = [0u32; 256];
2445    let mut max_sym = 0usize;
2446    for &c in codes {
2447        counts[c as usize] += 1;
2448        if c as usize > max_sym {
2449            max_sym = c as usize;
2450        }
2451    }
2452
2453    let n_used = counts[..=max_sym].iter().filter(|&&c| c > 0).count();
2454
2455    // RLE: only one distinct symbol
2456    if n_used == 1 {
2457        let sym = codes[0];
2458        return SeqTableMode::Rle(sym);
2459    }
2460
2461    // Check if predefined table can represent all our symbols
2462    let predefined_ok = max_sym <= max_symbol_default
2463        && codes.iter().all(|&c| {
2464            let s = c as usize;
2465            s < default_norm.len() && default_norm[s] != 0
2466        });
2467
2468    // Try custom FSE table
2469    // Choose table_log: use max_log for best compression, but cap by number of symbols
2470    let table_log = {
2471        let min_log = 5u32;
2472        let symbol_log = if n_used <= 2 {
2473            min_log
2474        } else {
2475            std::cmp::min(max_log, (32 - (n_used as u32).leading_zeros()).max(min_log))
2476        };
2477        std::cmp::min(max_log, std::cmp::max(min_log, symbol_log))
2478    };
2479
2480    let custom_norm = normalize_counts(&counts, max_sym, table_log);
2481
2482    // Verify all symbols are covered
2483    let all_covered = codes.iter().all(|&c| {
2484        let s = c as usize;
2485        s <= max_sym && custom_norm[s] != 0
2486    });
2487
2488    if !all_covered {
2489        if predefined_ok {
2490            return SeqTableMode::Predefined;
2491        }
2492        // Last resort: can't encode
2493        return SeqTableMode::Predefined;
2494    }
2495
2496    let header_bytes = encode_fse_header(&custom_norm, max_sym, table_log);
2497
2498    // Bit-cost comparison (port of C zstd's ZSTD_selectEncodingType approach)
2499    let _nb_seq = codes.len();
2500
2501    // Cross-entropy cost for predefined table: sum of log2(tableSize/prob) per symbol
2502    let predefined_cost = if predefined_ok {
2503        cross_entropy_cost(default_norm, default_log, &counts, max_sym)
2504    } else {
2505        u64::MAX
2506    };
2507
2508    // Custom FSE cost: header bytes + cross-entropy with custom table
2509    let _custom_table_size = 1u64 << table_log;
2510    let mut custom_stream_cost = 0u64;
2511    for s in 0..=max_sym {
2512        if counts[s] > 0 {
2513            let prob = if custom_norm[s] == -1 {
2514                1u64
2515            } else {
2516                custom_norm[s] as u64
2517            };
2518            if prob == 0 {
2519                custom_stream_cost = u64::MAX;
2520                break;
2521            }
2522            // Cost in 256ths of a bit: count * log2(tableSize/prob) * 256
2523            // log2(tableSize/prob) = table_log - log2(prob)
2524            let log2_prob = 63 - prob.leading_zeros() as u64;
2525            custom_stream_cost += counts[s] as u64 * (table_log as u64 - log2_prob);
2526        }
2527    }
2528    let custom_header_cost = header_bytes.len() as u64 * 8;
2529    let custom_total_cost = custom_header_cost + custom_stream_cost + table_log as u64;
2530
2531    if predefined_ok && predefined_cost <= custom_total_cost {
2532        SeqTableMode::Predefined
2533    } else {
2534        SeqTableMode::Fse {
2535            norm: custom_norm,
2536            max_symbol: max_sym,
2537            table_log,
2538            header_bytes,
2539        }
2540    }
2541}
2542
2543/// Write the table description to `out` and return the built FSE compression table.
2544fn write_seq_table_and_build(
2545    out: &mut Vec<u8>,
2546    mode: &SeqTableMode,
2547    default_norm: &[i16],
2548    default_max_symbol: usize,
2549    default_log: u32,
2550) -> super::fse::FseCTable {
2551    match mode {
2552        SeqTableMode::Predefined => {
2553            super::fse::FseCTable::build(default_norm, default_max_symbol, default_log)
2554        }
2555        SeqTableMode::Rle(sym) => {
2556            out.push(*sym);
2557            super::fse::FseCTable::build_rle(*sym)
2558        }
2559        SeqTableMode::Fse {
2560            norm,
2561            max_symbol,
2562            table_log,
2563            header_bytes,
2564        } => {
2565            out.extend_from_slice(header_bytes);
2566            super::fse::FseCTable::build(norm, *max_symbol, *table_log)
2567        }
2568    }
2569}
2570
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `frame` starts with the little-endian zstd magic number.
    fn assert_starts_with_magic(frame: &[u8]) {
        assert_eq!(&frame[..4], &ZSTD_MAGIC.to_le_bytes());
    }

    /// Empty input still yields a frame: magic + header + empty block.
    #[test]
    fn compress_empty() {
        let compressed = compress(&[], 1);
        assert!(compressed.len() >= 5); // magic + header + empty block
        assert_starts_with_magic(&compressed);
    }

    /// A short literal input produces a frame longer than the bare minimum.
    #[test]
    fn compress_small() {
        let compressed = compress(b"hello world", 1);
        assert_starts_with_magic(&compressed);
        assert!(compressed.len() > 5);
    }

    /// A run of one repeated byte produces a frame with the zstd magic.
    #[test]
    fn compress_repetitive() {
        let input = vec![42u8; 4096];
        let frame = compress(&input, 1);
        // Valid zstd frame (raw blocks are larger than input due to framing)
        assert_starts_with_magic(&frame);
    }

    /// Little-endian `f32` ramp data produces a frame with the zstd magic.
    #[test]
    fn compress_real_data() {
        let mut input: Vec<u8> = Vec::new();
        for i in 0..1024u32 {
            input.extend_from_slice(&(i as f32).to_le_bytes());
        }
        let frame = compress(&input, 1);
        assert_starts_with_magic(&frame);
    }

    /// Golden test: roundtrip through our compressor → our decompressor.
    #[test]
    fn roundtrip_self_contained() {
        let sequential: Vec<u8> = (0..4096u32).flat_map(|i| i.to_le_bytes()).collect();
        let f32_data: Vec<u8> = (0..256u32)
            .flat_map(|i| (i as f32 * 1.5).to_le_bytes())
            .collect();

        let test_cases: Vec<(&str, Vec<u8>)> = vec![
            ("zeros", vec![0u8; 4096]),
            ("sequential", sequential),
            ("f32_data", f32_data),
            ("repetitive", b"hello world! ".repeat(100)),
            ("small", b"abc".to_vec()),
        ];

        for (name, data) in test_cases.iter() {
            let frame = compress(data, 1);
            let decompressed = match crate::decompress(&frame) {
                Ok(bytes) => bytes,
                Err(e) => panic!("{}: decompress failed: {}", name, e),
            };

            assert_eq!(decompressed.len(), data.len(), "{}: length mismatch", name);
            assert_eq!(&decompressed, data, "{}: data mismatch", name);
        }
    }
}