Skip to main content

haagenti_zstd/fse/
tans_encoder.rs

1//! Correct tANS (table-based Asymmetric Numeral Systems) encoder for Zstd sequences.
2//!
3//! This implements the exact FSE encoding algorithm from Zstd's reference implementation.
4//!
5//! ## Algorithm (from Zstd fse.h)
6//!
7//! ```text
8//! nbBitsOut = (state + deltaNbBits) >> 16
9//! output low nbBitsOut bits of state
10//! next_state = stateTable[(state >> nbBitsOut) + deltaFindState]
11//! ```
12//!
13//! Where:
14//! - `deltaNbBits = (maxBitsOut << 16) - minStatePlus`
15//! - `minStatePlus = prob << maxBitsOut`
16//! - `maxBitsOut = tableLog - highbit32(prob - 1)` for prob > 1
17//! - `deltaFindState = cumulative_total - prob`
18//!
19//! ## State Table
20//!
21//! The state table stores `tableSize + position` for each decode table position,
22//! ensuring encoder states are in [tableSize, 2*tableSize).
23//!
24//! ## Performance
25//!
26//! Uses `Arc<[T]>` for immutable table data, making Clone O(1) via reference
27//! counting instead of O(n) deep copies. Only the mutable `state` is copied.
28
29use super::table::FseTable;
30use std::sync::Arc;
31
32use super::{cloned_ll_encoder, cloned_ml_encoder, cloned_of_encoder};
33
/// Per-symbol constants consumed by the tANS encoding step.
///
/// One instance exists for every symbol index up to the highest symbol that
/// occurs in the table the encoder was built from.
#[derive(Clone, Copy, Debug, Default)]
pub struct TansSymbolParams {
    /// Offset added to the current state to derive the output bit count:
    /// `nb_bits = (state + delta_nb_bits) >> 16`.
    pub delta_nb_bits: u32,
    /// Offset added to the shifted state to locate the next state:
    /// `idx = (state >> nb_bits) + delta_find_state`.
    ///
    /// Signed because the offset can be negative.
    pub delta_find_state: i32,
}
44
45/// tANS encoder for a single stream.
46///
47/// Uses `Arc<[T]>` for immutable tables, making Clone O(1) instead of O(n).
48/// Only the mutable `state` field is copied on clone.
49#[derive(Debug, Clone)]
50pub struct TansEncoder {
51    /// Symbol encoding parameters (indexed by symbol).
52    /// Shared via Arc for O(1) clone.
53    symbol_params: Arc<[TansSymbolParams]>,
54    /// State table for finding next state.
55    /// Indexed by: (state >> nb_bits) + delta_find_state
56    /// Values are in [table_size, 2*table_size)
57    /// Shared via Arc for O(1) clone.
58    state_table: Arc<[u16]>,
59    /// Num bits to output for each decode state (indexed by decode_state = encoder_state - table_size).
60    /// Shared via Arc for O(1) clone.
61    #[allow(dead_code)]
62    num_bits_per_state: Arc<[u8]>,
63    /// Baseline for each decode state (for computing decoder's next state).
64    /// Shared via Arc for O(1) clone.
65    #[allow(dead_code)]
66    baseline_per_state: Arc<[u16]>,
67    /// Current encoder state (in [table_size, 2*table_size)).
68    state: u32,
69    /// Table size (1 << accuracy_log).
70    table_size: u32,
71    /// Accuracy log.
72    accuracy_log: u8,
73}
74
75impl TansEncoder {
76    /// Build a tANS encoder from a decode table.
77    ///
78    /// Implements the exact algorithm from Zstd's FSE_buildCTable_wksp.
79    pub fn from_decode_table(decode_table: &FseTable) -> Self {
80        let accuracy_log = decode_table.accuracy_log();
81        let table_size = decode_table.size() as u32;
82
83        // Count probability for each symbol from decode table
84        let mut symbol_probs = vec![0i32; 256];
85        for state in 0..table_size as usize {
86            let entry = decode_table.decode(state);
87            symbol_probs[entry.symbol as usize] += 1;
88        }
89
90        // Find highest symbol with non-zero probability
91        let max_symbol = symbol_probs
92            .iter()
93            .enumerate()
94            .rev()
95            .find(|&(_, &p)| p > 0)
96            .map(|(i, _)| i)
97            .unwrap_or(0);
98
99        // Compute cumulative probabilities
100        let mut cumul = vec![0u32; max_symbol + 2];
101        for i in 0..=max_symbol {
102            cumul[i + 1] = cumul[i] + symbol_probs[i].unsigned_abs();
103        }
104
105        // Build symbol encoding parameters (symbolTT in Zstd)
106        let mut symbol_params = vec![TansSymbolParams::default(); max_symbol + 1];
107        let mut total: u32 = 0;
108
109        for (symbol, &prob) in symbol_probs.iter().enumerate().take(max_symbol + 1) {
110            if prob == 0 {
111                // Symbol not present - use sentinel values
112                symbol_params[symbol] = TansSymbolParams {
113                    delta_nb_bits: ((accuracy_log as u32 + 1) << 16) - table_size,
114                    delta_find_state: 0,
115                };
116            } else if prob == 1 || prob == -1 {
117                // Special case for prob = 1 or -1 (less-than-one probability)
118                // deltaNbBits = (tableLog << 16) - (1 << tableLog)
119                symbol_params[symbol] = TansSymbolParams {
120                    delta_nb_bits: ((accuracy_log as u32) << 16).wrapping_sub(table_size),
121                    delta_find_state: total as i32 - 1,
122                };
123                total += 1;
124            } else {
125                // Normal case for prob > 1
126                // maxBitsOut = tableLog - highbit32(prob - 1)
127                let high_bit = 31 - (prob as u32 - 1).leading_zeros();
128                let max_bits_out = accuracy_log as u32 - high_bit;
129
130                // minStatePlus = prob << maxBitsOut
131                let min_state_plus = (prob as u32) << max_bits_out;
132
133                // deltaNbBits = (maxBitsOut << 16) - minStatePlus
134                let delta_nb_bits = (max_bits_out << 16).wrapping_sub(min_state_plus);
135
136                // deltaFindState = total - prob
137                let delta_find_state = total as i32 - prob;
138
139                symbol_params[symbol] = TansSymbolParams {
140                    delta_nb_bits,
141                    delta_find_state,
142                };
143                total += prob as u32;
144            }
145        }
146
147        // Build state table using decode table order
148        // stateTable[cumul[s]++] = tableSize + position
149        let mut state_table = vec![0u16; table_size as usize];
150        let mut cumul_copy = cumul.clone();
151
152        for position in 0..table_size as usize {
153            let symbol = decode_table.decode(position).symbol as usize;
154            if symbol <= max_symbol {
155                let idx = cumul_copy[symbol] as usize;
156                if idx < state_table.len() {
157                    // Store tableSize + position (state range: [tableSize, 2*tableSize))
158                    state_table[idx] = (table_size + position as u32) as u16;
159                    cumul_copy[symbol] += 1;
160                }
161            }
162        }
163
164        // Copy num_bits and baseline from decode table for each state
165        // These are used during encoding to output the correct number of bits
166        let mut num_bits_per_state = vec![0u8; table_size as usize];
167        let mut baseline_per_state = vec![0u16; table_size as usize];
168        for position in 0..table_size as usize {
169            let entry = decode_table.decode(position);
170            num_bits_per_state[position] = entry.num_bits;
171            baseline_per_state[position] = entry.baseline;
172        }
173
174        Self {
175            symbol_params: symbol_params.into(),
176            state_table: state_table.into(),
177            num_bits_per_state: num_bits_per_state.into(),
178            baseline_per_state: baseline_per_state.into(),
179            state: table_size, // Initial state
180            table_size,
181            accuracy_log,
182        }
183    }
184
185    /// Initialize encoder state for a specific symbol.
186    ///
187    /// Implements FSE_initCState2 from Zstd:
188    /// ```text
189    /// nbBitsOut = (deltaNbBits + (1<<15)) >> 16  // with rounding
190    /// value = (nbBitsOut << 16) - deltaNbBits
191    /// state = stateTable[(value >> nbBitsOut) + deltaFindState]
192    /// ```
193    /// This sets up the encoder for the first (last in sequence) symbol
194    /// without outputting any bits.
195    pub fn init_state(&mut self, symbol: u8) {
196        let sym_idx = symbol as usize;
197        if sym_idx >= self.symbol_params.len() {
198            self.state = self.table_size;
199            return;
200        }
201
202        let params = &self.symbol_params[sym_idx];
203
204        // FSE_initCState2 algorithm:
205        // 1. nbBitsOut = (deltaNbBits + (1<<15)) >> 16
206        let nb_bits_out = ((params.delta_nb_bits as u64 + 0x8000) >> 16) as u32;
207
208        // 2. value = (nbBitsOut << 16) - deltaNbBits
209        let value = ((nb_bits_out as u64) << 16).wrapping_sub(params.delta_nb_bits as u64) as u32;
210
211        // 3. state = stateTable[(value >> nbBitsOut) + deltaFindState]
212        let value_shifted = if nb_bits_out >= 32 {
213            0
214        } else {
215            value >> nb_bits_out
216        };
217        let idx = value_shifted as i64 + params.delta_find_state as i64;
218
219        if idx >= 0 && (idx as usize) < self.state_table.len() {
220            self.state = self.state_table[idx as usize] as u32;
221        } else {
222            // Fallback to tableSize
223            self.state = self.table_size;
224        }
225    }
226
227    /// Encode a symbol and return (bits, num_bits).
228    ///
229    /// Implements FSE_encodeSymbol from Zstd:
230    /// ```text
231    /// nbBitsOut = (state + deltaNbBits) >> 16
232    /// output low nbBitsOut bits of state
233    /// state = stateTable[(state >> nbBitsOut) + deltaFindState]
234    /// ```
235    #[inline]
236    pub fn encode_symbol(&mut self, symbol: u8) -> (u32, u8) {
237        let sym_idx = symbol as usize;
238
239        if sym_idx >= self.symbol_params.len() {
240            return (0, 0);
241        }
242
243        let params = &self.symbol_params[sym_idx];
244
245        // Step 1: nbBitsOut = (state + deltaNbBits) >> 16
246        // This is the core Zstd FSE formula
247        let nb_bits_out = ((self.state as u64 + params.delta_nb_bits as u64) >> 16) as u8;
248
249        // Step 2: Output low nbBitsOut bits of state
250        let bits_mask = if nb_bits_out >= 32 {
251            u32::MAX
252        } else {
253            (1u32 << nb_bits_out) - 1
254        };
255        let bits = self.state & bits_mask;
256
257        // Step 3: state = stateTable[(state >> nbBitsOut) + deltaFindState]
258        let state_shifted = if nb_bits_out >= 32 {
259            0
260        } else {
261            self.state >> nb_bits_out
262        };
263        let idx = state_shifted as i64 + params.delta_find_state as i64;
264
265        let next_state = if idx >= 0 && (idx as usize) < self.state_table.len() {
266            self.state_table[idx as usize] as u32
267        } else {
268            self.table_size
269        };
270
271        self.state = next_state;
272        (bits, nb_bits_out)
273    }
274
275    /// Get current state for serialization.
276    ///
277    /// The decoder reads this as the initial state.
278    /// For the bitstream, we write (state - tableSize) masked to accuracy_log bits.
279    #[inline]
280    pub fn get_state(&self) -> u32 {
281        // Return the decode state (0 to tableSize-1)
282        self.state.saturating_sub(self.table_size) & ((1 << self.accuracy_log) - 1)
283    }
284
285    /// Get accuracy log.
286    #[inline]
287    pub fn accuracy_log(&self) -> u8 {
288        self.accuracy_log
289    }
290
291    /// Reset encoder for new stream.
292    pub fn reset(&mut self) {
293        self.state = self.table_size;
294    }
295}
296
297/// Interleaved tANS encoder for Zstd sequences (LL, OF, ML).
298#[derive(Debug)]
299pub struct InterleavedTansEncoder {
300    ll_encoder: TansEncoder,
301    of_encoder: TansEncoder,
302    ml_encoder: TansEncoder,
303}
304
305impl InterleavedTansEncoder {
306    /// Create encoder from three FSE tables.
307    ///
308    /// Note: For predefined tables, prefer `new_predefined()` which uses
309    /// cached encoders for better performance.
310    pub fn new(ll_table: &FseTable, of_table: &FseTable, ml_table: &FseTable) -> Self {
311        Self {
312            ll_encoder: TansEncoder::from_decode_table(ll_table),
313            of_encoder: TansEncoder::from_decode_table(of_table),
314            ml_encoder: TansEncoder::from_decode_table(ml_table),
315        }
316    }
317
318    /// Create encoder using cached predefined tANS encoders.
319    ///
320    /// This is the fast path for encoding with predefined FSE tables.
321    /// Cloning cached encoders is much faster than building from scratch.
322    #[inline]
323    pub fn new_predefined() -> Self {
324        Self {
325            ll_encoder: cloned_ll_encoder(),
326            of_encoder: cloned_of_encoder(),
327            ml_encoder: cloned_ml_encoder(),
328        }
329    }
330
331    /// Create encoder from pre-built tANS encoders.
332    ///
333    /// This allows using custom FSE tables by building encoders separately.
334    pub fn from_encoders(
335        ll_encoder: TansEncoder,
336        of_encoder: TansEncoder,
337        ml_encoder: TansEncoder,
338    ) -> Self {
339        Self {
340            ll_encoder,
341            of_encoder,
342            ml_encoder,
343        }
344    }
345
346    /// Initialize all three encoders with their first symbols.
347    ///
348    /// These should be the LAST symbols in the sequence (encoding is reversed).
349    pub fn init_states(&mut self, ll: u8, of: u8, ml: u8) {
350        self.ll_encoder.init_state(ll);
351        self.of_encoder.init_state(of);
352        self.ml_encoder.init_state(ml);
353    }
354
355    /// Encode one sequence (LL, OF, ML) and return bits for each.
356    ///
357    /// Returns [(ll_bits, ll_nbits), (of_bits, of_nbits), (ml_bits, ml_nbits)]
358    ///
359    /// Note: tANS encoding order matters! We encode in reverse of decode order.
360    /// Decoder reads FSE updates: LL, ML, OF
361    /// So we encode: OF, ML, LL (reverse order for correct state transitions)
362    #[inline]
363    pub fn encode_sequence(&mut self, ll: u8, of: u8, ml: u8) -> [(u32, u8); 3] {
364        // Encode in reverse of decoder read order: OF, ML, LL
365        let of_bits = self.of_encoder.encode_symbol(of);
366        let ml_bits = self.ml_encoder.encode_symbol(ml);
367        let ll_bits = self.ll_encoder.encode_symbol(ll);
368        // Return in standard order for caller convenience
369        [ll_bits, of_bits, ml_bits]
370    }
371
372    /// Get final states for all three encoders.
373    ///
374    /// These become the decoder's initial states.
375    #[inline]
376    pub fn get_states(&self) -> (u32, u32, u32) {
377        (
378            self.ll_encoder.get_state(),
379            self.of_encoder.get_state(),
380            self.ml_encoder.get_state(),
381        )
382    }
383
384    /// Get accuracy logs for all three encoders.
385    #[inline]
386    pub fn accuracy_logs(&self) -> (u8, u8, u8) {
387        (
388            self.ll_encoder.accuracy_log(),
389            self.of_encoder.accuracy_log(),
390            self.ml_encoder.accuracy_log(),
391        )
392    }
393
394    /// Reset all encoders.
395    pub fn reset(&mut self) {
396        self.ll_encoder.reset();
397        self.of_encoder.reset();
398        self.ml_encoder.reset();
399    }
400}
401
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fse::{
        LITERAL_LENGTH_ACCURACY_LOG, LITERAL_LENGTH_DEFAULT_DISTRIBUTION,
        MATCH_LENGTH_ACCURACY_LOG, MATCH_LENGTH_DEFAULT_DISTRIBUTION, OFFSET_ACCURACY_LOG,
        OFFSET_DEFAULT_DISTRIBUTION,
    };

    /// Fresh single-stream encoder built from the predefined literal-length table.
    fn literal_length_encoder() -> TansEncoder {
        let table = FseTable::from_predefined(
            &LITERAL_LENGTH_DEFAULT_DISTRIBUTION,
            LITERAL_LENGTH_ACCURACY_LOG,
        )
        .unwrap();
        TansEncoder::from_decode_table(&table)
    }

    /// The encoder reports the accuracy log and table size of its source table.
    #[test]
    fn test_tans_encoder_creation() {
        let encoder = literal_length_encoder();

        assert_eq!(encoder.accuracy_log(), LITERAL_LENGTH_ACCURACY_LOG);
        assert_eq!(encoder.table_size, 1 << LITERAL_LENGTH_ACCURACY_LOG);
    }

    /// After `init_state` the state lies in `[table_size, 2 * table_size)`.
    #[test]
    fn test_tans_encoder_state_range() {
        let mut encoder = literal_length_encoder();
        let table_size = encoder.table_size;

        encoder.init_state(0);

        assert!(
            encoder.state >= table_size,
            "State {} should be >= table_size {}",
            encoder.state,
            table_size
        );
        assert!(
            encoder.state < 2 * table_size,
            "State {} should be < 2*table_size {}",
            encoder.state,
            2 * table_size
        );
    }

    /// Repeatedly encoding symbol 0 keeps the outputs and the state in range.
    #[test]
    fn test_tans_encoder_encode_symbol() {
        let mut encoder = literal_length_encoder();
        encoder.init_state(0);

        let table_size = encoder.table_size;

        for _ in 0..20 {
            let (bits, num_bits) = encoder.encode_symbol(0);

            // Never more than accuracy_log + 1 bits per symbol.
            assert!(
                num_bits <= LITERAL_LENGTH_ACCURACY_LOG + 1,
                "num_bits {} too large",
                num_bits
            );

            // The emitted value must fit in the advertised width.
            if (1..32).contains(&num_bits) {
                assert!(
                    bits < (1 << num_bits),
                    "bits {} doesn't fit in {} bits",
                    bits,
                    num_bits
                );
            }

            // The state stays inside the encoder's window.
            assert!(encoder.state >= table_size);
            assert!(encoder.state < 2 * table_size);
        }
    }

    /// Each of the 36 literal-length codes can be initialised and encoded.
    #[test]
    fn test_tans_encoder_all_symbols() {
        let mut encoder = literal_length_encoder();
        let table_size = encoder.table_size;

        for symbol in 0..36u8 {
            encoder.init_state(symbol);

            let (_bits, num_bits) = encoder.encode_symbol(symbol);

            assert!(
                num_bits <= LITERAL_LENGTH_ACCURACY_LOG + 1,
                "Symbol {} produced {} bits",
                symbol,
                num_bits
            );

            assert!(
                encoder.state >= table_size,
                "Symbol {} left state {} < table_size",
                symbol,
                encoder.state
            );
            assert!(
                encoder.state < 2 * table_size,
                "Symbol {} left state {} >= 2*table_size",
                symbol,
                encoder.state
            );
        }
    }

    /// One interleaved sequence yields bounded bit counts and valid decode states.
    #[test]
    fn test_interleaved_encoder() {
        let ll_table = FseTable::from_predefined(
            &LITERAL_LENGTH_DEFAULT_DISTRIBUTION,
            LITERAL_LENGTH_ACCURACY_LOG,
        )
        .unwrap();
        let ml_table = FseTable::from_predefined(
            &MATCH_LENGTH_DEFAULT_DISTRIBUTION,
            MATCH_LENGTH_ACCURACY_LOG,
        )
        .unwrap();
        let of_table =
            FseTable::from_predefined(&OFFSET_DEFAULT_DISTRIBUTION, OFFSET_ACCURACY_LOG).unwrap();

        let mut encoder = InterleavedTansEncoder::new(&ll_table, &of_table, &ml_table);
        encoder.init_states(0, 0, 0);

        let [ll_bits, of_bits, ml_bits] = encoder.encode_sequence(0, 0, 0);

        // Bit counts are bounded by each stream's accuracy log + 1.
        assert!(ll_bits.1 <= LITERAL_LENGTH_ACCURACY_LOG + 1);
        assert!(of_bits.1 <= OFFSET_ACCURACY_LOG + 1);
        assert!(ml_bits.1 <= MATCH_LENGTH_ACCURACY_LOG + 1);

        let (ll_state, of_state, ml_state) = encoder.get_states();

        // Serialized states are decode states: 0 to table_size - 1.
        assert!(ll_state < (1 << LITERAL_LENGTH_ACCURACY_LOG));
        assert!(of_state < (1 << OFFSET_ACCURACY_LOG));
        assert!(ml_state < (1 << MATCH_LENGTH_ACCURACY_LOG));
    }
}
569
570#[cfg(test)]
571mod debug_tests {
572    use super::*;
573    use crate::fse::{
574        BitReader, FseBitWriter, FseDecoder, FseTable, LITERAL_LENGTH_ACCURACY_LOG,
575        LITERAL_LENGTH_DEFAULT_DISTRIBUTION, MATCH_LENGTH_ACCURACY_LOG,
576        MATCH_LENGTH_DEFAULT_DISTRIBUTION, OFFSET_ACCURACY_LOG, OFFSET_DEFAULT_DISTRIBUTION,
577    };
578
    /// Diagnostic trace: build the FSE bitstream for a single sequence and
    /// print it next to the reference bytes.
    ///
    /// This test makes no assertions. It only prints the two bitstreams and
    /// what a `BitReader` decodes from each, so it can fail only through one
    /// of its `unwrap()` calls. Run with `--nocapture` to see the output.
    #[test]
    fn test_build_exact_reference_bitstream() {
        println!("=== Build Exact Reference Bitstream ===\n");

        // Reference values this trace is compared against:
        //   codes:       LL=4, OF=2, ML=41
        //   states:      LL=4, OF=14, ML=19
        //   FSE bytes:   [0xfd, 0xe4, 0x88]
        //   extra bits:  ll_extra=0 (0 bits), of_extra=3 (2 bits), ml_extra=13 (4 bits)

        // Predefined decode tables for the three streams.
        let ll_table = FseTable::from_predefined(
            &LITERAL_LENGTH_DEFAULT_DISTRIBUTION,
            LITERAL_LENGTH_ACCURACY_LOG,
        )
        .unwrap();
        let of_table =
            FseTable::from_predefined(&OFFSET_DEFAULT_DISTRIBUTION, OFFSET_ACCURACY_LOG).unwrap();
        let ml_table = FseTable::from_predefined(
            &MATCH_LENGTH_DEFAULT_DISTRIBUTION,
            MATCH_LENGTH_ACCURACY_LOG,
        )
        .unwrap();

        // Interleaved encoder over the three tables.
        let mut tans = InterleavedTansEncoder::new(&ll_table, &of_table, &ml_table);

        // Codes and extra-bit payloads of the single sequence being encoded.
        let ll_code = 4u8;
        let of_code = 2u8;
        let ml_code = 41u8;
        let of_extra = 3u32; // 2 bits
        let of_bits = 2u8;
        let ml_extra = 13u32; // 4 bits
        let ml_bits = 4u8;
        let ll_extra = 0u32; // 0 bits
        let ll_bits = 0u8;

        println!("Codes: LL={}, OF={}, ML={}", ll_code, of_code, ml_code);
        println!(
            "Extras: LL={}({} bits), OF={}({} bits), ML={}({} bits)",
            ll_extra, ll_bits, of_extra, of_bits, ml_extra, ml_bits
        );

        // With a single sequence the encoder is only initialised, never
        // stepped: the init states are also the final states written below.
        tans.init_states(ll_code, of_code, ml_code);
        let (ll_state, of_state, ml_state) = tans.get_states();
        println!(
            "Init states: LL={}, OF={}, ML={}",
            ll_state, of_state, ml_state
        );

        // Assemble the bitstream by hand.
        // NOTE(review): intended to mirror `build_fse_bitstream` for one
        // sequence; that function is not visible here, so confirm the order.
        let mut bits = FseBitWriter::new();

        // 1. Extra bits of the LAST sequence go first, in the order
        //    OF, ML, LL (reverse of the read order). Zero-width fields are skipped.
        println!("\nWriting extra bits:");
        if of_bits > 0 {
            println!("  OF extra: {} ({} bits)", of_extra, of_bits);
            bits.write_bits(of_extra, of_bits);
        }
        if ml_bits > 0 {
            println!("  ML extra: {} ({} bits)", ml_extra, ml_bits);
            bits.write_bits(ml_extra, ml_bits);
        }
        if ll_bits > 0 {
            println!("  LL extra: {} ({} bits)", ll_extra, ll_bits);
            bits.write_bits(ll_extra, ll_bits);
        }

        // 2. No further sequences to encode (single sequence).

        // 3. Final states are written in ML, OF, LL order, each using its
        //    table's accuracy log as the field width.
        println!("\nWriting states:");
        println!(
            "  ML state: {} ({} bits)",
            ml_state, MATCH_LENGTH_ACCURACY_LOG
        );
        bits.write_bits(ml_state, MATCH_LENGTH_ACCURACY_LOG);
        println!("  OF state: {} ({} bits)", of_state, OFFSET_ACCURACY_LOG);
        bits.write_bits(of_state, OFFSET_ACCURACY_LOG);
        println!(
            "  LL state: {} ({} bits)",
            ll_state, LITERAL_LENGTH_ACCURACY_LOG
        );
        bits.write_bits(ll_state, LITERAL_LENGTH_ACCURACY_LOG);

        let our_bitstream = bits.finish();
        let ref_bitstream = [0xfd, 0xe4, 0x88];

        println!("\nOur bitstream: {:02x?}", our_bitstream);
        println!("Ref bitstream: {:02x?}", ref_bitstream);

        // Byte-by-byte comparison over the three reference bytes; a missing
        // byte on our side is shown as 0.
        println!("\nBit comparison (LSB first):");
        for i in 0..3 {
            let our_byte = our_bitstream.get(i).copied().unwrap_or(0);
            let ref_byte = ref_bitstream[i];
            println!(
                "  Byte {}: our={:08b}, ref={:08b}, diff={}",
                i,
                our_byte,
                ref_byte,
                if our_byte == ref_byte {
                    "MATCH"
                } else {
                    "DIFFER"
                }
            );
        }

        // Payload size: 2 (of_extra) + 4 (ml_extra) + 0 (ll_extra)
        // + 6 (ml_state) + 5 (of_state) + 6 (ll_state) = 23 bits.
        // The literal 17 below is the sum of the three state widths (6 + 5 + 6).
        println!(
            "\nTotal bits: {} + {} + {} + 6 + 5 + 6 = {} bits",
            of_bits,
            ml_bits,
            ll_bits,
            of_bits as usize + ml_bits as usize + ll_bits as usize + 17
        );

        // Reference bytes in binary (MSB first within each byte):
        //   [fd, e4, 88] = [11111101, 11100100, 10001000]
        // The stream is read backward from its end. In 0x88 the top set bit
        // (bit 7) is the end marker, leaving 7 data bits in that byte and
        // 7 + 8 + 8 = 23 data bits overall, which matches the total above.

        println!("\nBit-level analysis:");

        // NOTE(review): 0xF7 below is a hard-coded value for "our" first byte,
        // not read from `our_bitstream`; the printed analysis does not follow
        // changes in the writer. Bits 0-1 hold OF extra, bits 2-5 hold ML extra.
        //
        // NOTE(review): the two values labelled "LSB-first" and "MSB-first"
        // are computed identically (both weight bit 5 highest), so each line
        // prints the same number twice. A real MSB-first read would weight
        // bit 2 highest.
        println!(
            "Our bits 2-5: {} {} {} {} = {} (LSB-first) or {} (MSB-first)",
            (0xF7 >> 2) & 1,
            (0xF7 >> 3) & 1,
            (0xF7 >> 4) & 1,
            (0xF7 >> 5) & 1,
            (0xF7 >> 2) & 0xF, // LSB-first read
            ((0xF7 >> 5) & 1) << 3
                | ((0xF7 >> 4) & 1) << 2
                | ((0xF7 >> 3) & 1) << 1
                | ((0xF7 >> 2) & 1)  // MSB-first read
        );

        println!(
            "Ref bits 2-5: {} {} {} {} = {} (LSB-first) or {} (MSB-first)",
            (0xFD >> 2) & 1,
            (0xFD >> 3) & 1,
            (0xFD >> 4) & 1,
            (0xFD >> 5) & 1,
            (0xFD >> 2) & 0xF,
            ((0xFD >> 5) & 1) << 3
                | ((0xFD >> 4) & 1) << 2
                | ((0xFD >> 3) & 1) << 1
                | ((0xFD >> 2) & 1)
        );

        // Writing 13 = 0b1101 LSB-first at bits 2-5 gives
        // bit2=1, bit3=0, bit4=1, bit5=1, which reads back as
        // (bit5<<3)|(bit4<<2)|(bit3<<1)|bit2 = 8+4+0+1 = 13.
        // So if the reference's bits 2-5 differ from ours, the question is
        // which value the reference encodes there. The two decodes below
        // answer that by reading both streams with the same BitReader calls.

        // Decode the reference stream from its end.
        let mut ref_bits = BitReader::new(&ref_bitstream);
        ref_bits.init_from_end().unwrap();

        // States are read back as LL, OF, ML: the reverse of the ML, OF, LL
        // write order used above.
        let ref_ll = ref_bits.read_bits(6).unwrap();
        let ref_of = ref_bits.read_bits(5).unwrap();
        let ref_ml = ref_bits.read_bits(6).unwrap();
        println!(
            "\nReference decoded states: LL={}, OF={}, ML={}",
            ref_ll, ref_of, ref_ml
        );

        // Extras: LL has no extra bits for code 4, then ML (4 bits), then OF (2 bits).
        let ref_ll_extra = 0u32; // 0 bits for LL code 4
        let ref_ml_extra = ref_bits.read_bits(4).unwrap();
        let ref_of_extra = ref_bits.read_bits(2).unwrap();
        println!(
            "Reference decoded extras: LL_extra={}, ML_extra={}, OF_extra={}",
            ref_ll_extra, ref_ml_extra, ref_of_extra
        );

        // Decode our own stream with the identical sequence of reads.
        let mut our_bits = BitReader::new(&our_bitstream);
        our_bits.init_from_end().unwrap();

        let our_ll = our_bits.read_bits(6).unwrap();
        let our_of = our_bits.read_bits(5).unwrap();
        let our_ml = our_bits.read_bits(6).unwrap();
        println!(
            "\nOur decoded states: LL={}, OF={}, ML={}",
            our_ll, our_of, our_ml
        );

        let our_ll_extra = 0u32;
        let our_ml_extra = our_bits.read_bits(4).unwrap();
        let our_of_extra = our_bits.read_bits(2).unwrap();
        println!(
            "Our decoded extras: LL_extra={}, ML_extra={}, OF_extra={}",
            our_ll_extra, our_ml_extra, our_of_extra
        );

        // Open question this trace is meant to settle: does the reference
        // carry different extra values, or does the bit read/write order differ?
    }
799
    /// Trace a `BitReader` over the reference FSE bytes `[0xfd, 0xe4, 0x88]`
    /// and check the decoded fields.
    ///
    /// Reads the three initial states (LL 6 bits, OF 5 bits, ML 6 bits), then
    /// the ML and OF extra bits, printing the remaining bit count after each
    /// step. The assertions run only after every field has been read and printed.
    #[test]
    fn test_trace_bit_reading() {
        println!("=== Tracing Bit Reading from Reference FSE Bytes ===\n");

        let fse_bytes = [0xfd, 0xe4, 0x88];
        println!("Bytes: {:02x?}", fse_bytes);
        println!("Binary:");
        println!("  0xFD = {:08b} (bits 0-7)", 0xFD);
        println!("  0xE4 = {:08b} (bits 8-15)", 0xE4);
        println!("  0x88 = {:08b} (bits 16-23)", 0x88);

        // Start reading from the end of the buffer (reversed mode).
        let mut bits = BitReader::new(&fse_bytes);
        bits.init_from_end().unwrap();
        println!("\nBits available after init: {}", bits.bits_remaining());

        // Initial states, in LL, OF, ML order.
        let ll_state = bits.read_bits(6).unwrap();
        println!("\nRead LL state (6 bits): {} (expect 4)", ll_state);
        println!("  Bits remaining: {}", bits.bits_remaining());

        let of_state = bits.read_bits(5).unwrap();
        println!("Read OF state (5 bits): {} (expect 14)", of_state);
        println!("  Bits remaining: {}", bits.bits_remaining());

        let ml_state = bits.read_bits(6).unwrap();
        println!("Read ML state (6 bits): {} (expect 19)", ml_state);
        println!("  Bits remaining: {}", bits.bits_remaining());

        // Switch the reader to LSB-first mode before the extra bits.
        // NOTE(review): presumably because the extras sit at the start of the
        // bitstream and are read from bit 0 upward; confirm against
        // `BitReader::switch_to_lsb_mode`.
        bits.switch_to_lsb_mode().unwrap();

        // Extras are read in sequence order LL, ML, OF:
        //   LL code 4  -> 0 extra bits
        //   ML code 41 -> 4 extra bits
        //   OF code 2  -> 2 extra bits

        let ll_extra = 0u32; // 0 bits for LL code 4
        println!(
            "\nRead LL extra (0 bits): {} (no extra for code 4)",
            ll_extra
        );
        println!("  Bits remaining: {}", bits.bits_remaining());

        let ml_extra = bits.read_bits(4).unwrap();
        println!("Read ML extra (4 bits): {} (expect 13)", ml_extra);
        println!("  Bits remaining: {}", bits.bits_remaining());

        let of_extra = bits.read_bits(2).unwrap();
        println!("Read OF extra (2 bits): {} (expect 3)", of_extra);
        println!("  Bits remaining: {}", bits.bits_remaining());

        // Check every decoded field against the reference values.
        assert_eq!(ll_state, 4, "LL state mismatch");
        assert_eq!(of_state, 14, "OF state mismatch");
        assert_eq!(ml_state, 19, "ML state mismatch");
        assert_eq!(ml_extra, 13, "ML extra mismatch - THIS IS THE BUG!");
        assert_eq!(of_extra, 3, "OF extra mismatch");

        // Resulting match length: baseline plus extra bits.
        let match_length = 83 + ml_extra; // baseline 83 for ML code 41
        println!("\nMatch length: 83 + {} = {}", ml_extra, match_length);
        println!(
            "Total bytes: 4 (literals) + {} (match) = {}",
            match_length,
            4 + match_length
        );
    }
870
871    /// Test full reference frame decompression
872    #[test]
873    fn test_full_reference_frame_decode() {
874        // Reference zstd -1 --no-check of "ABCD"x25
875        let ref_frame: [u8; 19] = [
876            0x28, 0xb5, 0x2f, 0xfd, // magic
877            0x20, // FHD: Single_Segment=1, no checksum
878            0x64, // content size = 100
879            0x55, 0x00, 0x00, // block header
880            0x20, // literals header
881            0x41, 0x42, 0x43, 0x44, // literals "ABCD"
882            0x01, // 1 sequence
883            0x00, // mode byte (predefined tables)
884            0xfd, 0xe4, 0x88, // FSE bitstream
885        ];
886
887        // Decompress with our decoder
888        let decompressed = crate::decompress::decompress_frame(&ref_frame)
889            .expect("Failed to decompress reference frame");
890        let expected = "ABCD".repeat(25);
891
892        println!("Decompressed length: {}", decompressed.len());
893        println!("Expected length: {}", expected.len());
894        println!(
895            "First 20 bytes: {:?}",
896            &decompressed[..20.min(decompressed.len())]
897        );
898
899        assert_eq!(decompressed.len(), 100, "Length mismatch");
900        assert_eq!(decompressed, expected.as_bytes(), "Content mismatch");
901        println!("Reference frame decompression verified!");
902    }
903
904    /// Debug test to decode reference FSE bytes and understand what they encode.
905    #[test]
906    fn test_decode_reference_fse_bytes() {
907        // Reference FSE bytes from zstd -1 compression of "ABCD"x25
908        let fse_bytes = [0xfd, 0xe4, 0x88];
909
910        println!("=== Decoding Reference FSE Bytes ===");
911        println!("Bytes: {:02x?}", fse_bytes);
912
913        // Build predefined tables
914        let ll_table = FseTable::from_predefined(
915            &LITERAL_LENGTH_DEFAULT_DISTRIBUTION,
916            LITERAL_LENGTH_ACCURACY_LOG,
917        )
918        .unwrap();
919        let of_table =
920            FseTable::from_predefined(&OFFSET_DEFAULT_DISTRIBUTION, OFFSET_ACCURACY_LOG).unwrap();
921        let ml_table = FseTable::from_predefined(
922            &MATCH_LENGTH_DEFAULT_DISTRIBUTION,
923            MATCH_LENGTH_ACCURACY_LOG,
924        )
925        .unwrap();
926
927        // Create decoders
928        let mut ll_decoder = FseDecoder::new(&ll_table);
929        let mut of_decoder = FseDecoder::new(&of_table);
930        let mut ml_decoder = FseDecoder::new(&ml_table);
931
932        // Init bit reader from end
933        let mut bits = BitReader::new(&fse_bytes);
934        bits.init_from_end().unwrap();
935        println!("Bits available after init: {}", bits.bits_remaining());
936
937        // Read initial states (LL, OF, ML order)
938        ll_decoder.init_state(&mut bits).unwrap();
939        of_decoder.init_state(&mut bits).unwrap();
940        ml_decoder.init_state(&mut bits).unwrap();
941
942        let ll_state = ll_decoder.state();
943        let of_state = of_decoder.state();
944        let ml_state = ml_decoder.state();
945
946        println!(
947            "Initial states: LL={}, OF={}, ML={}",
948            ll_state, of_state, ml_state
949        );
950        println!("Bits remaining after states: {}", bits.bits_remaining());
951
952        // Get symbols from states
953        let ll_code = ll_table.decode(ll_state).symbol;
954        let of_code = of_table.decode(of_state).symbol;
955        let ml_code = ml_table.decode(ml_state).symbol;
956
957        println!("Symbols from states:");
958        println!("  LL code {} (from state {})", ll_code, ll_state);
959        println!("  OF code {} (from state {})", of_code, of_state);
960        println!("  ML code {} (from state {})", ml_code, ml_state);
961
962        // Print what these codes mean:
963        println!("\nCode meanings:");
964
965        // LL code interpretation
966        if ll_code <= 15 {
967            println!(
968                "  LL code {}: literal_length = {} (no extra bits)",
969                ll_code, ll_code
970            );
971        } else {
972            let extra_bits = match ll_code {
973                16..=17 => 1,
974                18..=19 => 1,
975                20..=21 => 2,
976                22..=23 => 3,
977                24..=25 => 4,
978                26..=27 => 5,
979                28..=29 => 6,
980                30..=31 => 7,
981                32..=33 => 8,
982                34..=35 => 9,
983                _ => 0,
984            };
985            println!("  LL code {}: needs {} extra bits", ll_code, extra_bits);
986        }
987
988        // OF code = offset code, number of extra bits = of_code
989        println!(
990            "  OF code {}: offset = 2^{} + {} extra bits",
991            of_code, of_code, of_code
992        );
993
994        // ML code interpretation
995        if ml_code <= 31 {
996            println!(
997                "  ML code {}: match_length = {} (no extra bits)",
998                ml_code,
999                ml_code + 3
1000            );
1001        } else {
1002            println!("  ML code {}: needs extra bits", ml_code);
1003        }
1004
1005        // Read remaining extra bits
1006        let remaining = bits.bits_remaining();
1007        println!("\nRemaining bits for extras: {}", remaining);
1008
1009        // For a single sequence with predefined tables, the extra bits should be:
1010        // LL extra, ML extra, OF extra (in read order)
1011        // But we need to know how many bits each needs
1012    }
1013
1014    /// Debug test to trace init_state calculation.
1015    #[test]
1016    fn test_trace_init_state() {
1017        println!("=== Tracing init_state calculation ===\n");
1018
1019        // Build LL table and encoder
1020        let ll_table = FseTable::from_predefined(
1021            &LITERAL_LENGTH_DEFAULT_DISTRIBUTION,
1022            LITERAL_LENGTH_ACCURACY_LOG,
1023        )
1024        .unwrap();
1025
1026        let encoder = TansEncoder::from_decode_table(&ll_table);
1027
1028        // Print which states decode to symbol 0
1029        println!("States that decode to LL symbol 0:");
1030        for state in 0..64 {
1031            let entry = ll_table.decode(state);
1032            if entry.symbol == 0 {
1033                println!(
1034                    "  State {}: symbol={}, num_bits={}, baseline={}",
1035                    state, entry.symbol, entry.num_bits, entry.baseline
1036                );
1037            }
1038        }
1039
1040        // Print symbol params for symbol 0
1041        let params = &encoder.symbol_params[0];
1042        println!("\nSymbol 0 params:");
1043        println!(
1044            "  delta_nb_bits: {} (0x{:x})",
1045            params.delta_nb_bits, params.delta_nb_bits
1046        );
1047        println!("  delta_find_state: {}", params.delta_find_state);
1048
1049        // Trace init_state calculation
1050        let sym_idx = 0usize;
1051        let nb_bits_out = ((params.delta_nb_bits as u64 + 0x8000) >> 16) as u32;
1052        let value = ((nb_bits_out as u64) << 16).wrapping_sub(params.delta_nb_bits as u64) as u32;
1053        let value_shifted = if nb_bits_out >= 32 {
1054            0
1055        } else {
1056            value >> nb_bits_out
1057        };
1058        let idx = value_shifted as i64 + params.delta_find_state as i64;
1059
1060        println!("\ninit_state(0) calculation:");
1061        println!(
1062            "  nb_bits_out = ({} + 0x8000) >> 16 = {}",
1063            params.delta_nb_bits, nb_bits_out
1064        );
1065        println!(
1066            "  value = ({} << 16) - {} = {}",
1067            nb_bits_out, params.delta_nb_bits, value
1068        );
1069        println!(
1070            "  value_shifted = {} >> {} = {}",
1071            value, nb_bits_out, value_shifted
1072        );
1073        println!(
1074            "  idx = {} + {} = {}",
1075            value_shifted, params.delta_find_state, idx
1076        );
1077        println!(
1078            "  state_table[{}] = {}",
1079            idx, encoder.state_table[idx as usize]
1080        );
1081        println!(
1082            "  Final decode_state = {} - 64 = {}",
1083            encoder.state_table[idx as usize],
1084            encoder.state_table[idx as usize] as i32 - 64
1085        );
1086
1087        // What init state does our encoder produce?
1088        let mut test_encoder = TansEncoder::from_decode_table(&ll_table);
1089        test_encoder.init_state(0);
1090        let our_state = test_encoder.get_state();
1091        println!("\nOur init_state(0) produces decode_state: {}", our_state);
1092
1093        // Verify it decodes to symbol 0
1094        let entry = ll_table.decode(our_state as usize);
1095        println!("State {} decodes to symbol {}", our_state, entry.symbol);
1096
1097        // What's at state 38 (reference)?
1098        let ref_entry = ll_table.decode(38);
1099        println!(
1100            "\nReference state 38 decodes to symbol {}",
1101            ref_entry.symbol
1102        );
1103    }
1104
1105    /// Test init_state for the specific codes used by reference.
1106    #[test]
1107    fn test_init_state_for_reference_codes() {
1108        println!("=== Init State for Reference Codes ===\n");
1109
1110        // Reference encodes: LL=4, OF=2, ML=41
1111        // Reference states: LL=4, OF=14, ML=19
1112
1113        // Build tables
1114        let ll_table = FseTable::from_predefined(
1115            &LITERAL_LENGTH_DEFAULT_DISTRIBUTION,
1116            LITERAL_LENGTH_ACCURACY_LOG,
1117        )
1118        .unwrap();
1119        let of_table =
1120            FseTable::from_predefined(&OFFSET_DEFAULT_DISTRIBUTION, OFFSET_ACCURACY_LOG).unwrap();
1121        let ml_table = FseTable::from_predefined(
1122            &MATCH_LENGTH_DEFAULT_DISTRIBUTION,
1123            MATCH_LENGTH_ACCURACY_LOG,
1124        )
1125        .unwrap();
1126
1127        // Build encoders
1128        let mut ll_encoder = TansEncoder::from_decode_table(&ll_table);
1129        let mut of_encoder = TansEncoder::from_decode_table(&of_table);
1130        let mut ml_encoder = TansEncoder::from_decode_table(&ml_table);
1131
1132        // Test LL code 4
1133        ll_encoder.init_state(4);
1134        let our_ll_state = ll_encoder.get_state();
1135        let ref_ll_state = 4u32;
1136        let ll_entry = ll_table.decode(our_ll_state as usize);
1137        println!("LL code 4:");
1138        println!("  Reference state: {}", ref_ll_state);
1139        println!("  Our state: {}", our_ll_state);
1140        println!("  Our state decodes to symbol: {}", ll_entry.symbol);
1141        println!("  Match: {}", our_ll_state == ref_ll_state);
1142
1143        // Test OF code 1 (repeat offset 2 = 4)
1144        of_encoder.init_state(1);
1145        let of_state_for_code1 = of_encoder.get_state();
1146        println!("\nOF code 1 (repeat offset 2 = 4):");
1147        println!("  Our state: {}", of_state_for_code1);
1148        println!(
1149            "  Decodes to symbol: {}",
1150            of_table.decode(of_state_for_code1 as usize).symbol
1151        );
1152
1153        // Test OF code 2
1154        of_encoder.init_state(2);
1155        let our_of_state = of_encoder.get_state();
1156        let ref_of_state = 14u32;
1157        let of_entry = of_table.decode(our_of_state as usize);
1158        println!("\nOF code 2:");
1159        println!("  Reference state: {}", ref_of_state);
1160        println!("  Our state: {}", our_of_state);
1161        println!("  Our state decodes to symbol: {}", of_entry.symbol);
1162        println!("  Match: {}", our_of_state == ref_of_state);
1163
1164        // Test ML code 41
1165        ml_encoder.init_state(41);
1166        let our_ml_state = ml_encoder.get_state();
1167        let ref_ml_state = 19u32;
1168        let ml_entry = ml_table.decode(our_ml_state as usize);
1169        println!("\nML code 41:");
1170        println!("  Reference state: {}", ref_ml_state);
1171        println!("  Our state: {}", our_ml_state);
1172        println!("  Our state decodes to symbol: {}", ml_entry.symbol);
1173        println!("  Match: {}", our_ml_state == ref_ml_state);
1174
1175        // Print which states decode to each symbol
1176        println!("\n--- States that decode to symbol 4 in LL table ---");
1177        for state in 0..64 {
1178            if ll_table.decode(state).symbol == 4 {
1179                println!("  State {}", state);
1180            }
1181        }
1182
1183        println!("\n--- Full OF table (state -> symbol) ---");
1184        for state in 0..32 {
1185            let entry = of_table.decode(state);
1186            println!(
1187                "  State {:2} -> symbol {:2} (num_bits={}, baseline={})",
1188                state, entry.symbol, entry.num_bits, entry.baseline
1189            );
1190        }
1191
1192        println!("\n--- States that decode to symbol 1 in OF table (for offset 4) ---");
1193        for state in 0..32 {
1194            if of_table.decode(state).symbol == 1 {
1195                println!("  State {}", state);
1196            }
1197        }
1198
1199        println!("\n--- States that decode to symbol 5 in OF table (for offset 4-7) ---");
1200        for state in 0..32 {
1201            if of_table.decode(state).symbol == 5 {
1202                println!("  State {}", state);
1203            }
1204        }
1205
1206        println!("\n--- States that decode to symbol 41 in ML table ---");
1207        for state in 0..64 {
1208            if ml_table.decode(state).symbol == 41 {
1209                println!("  State {}", state);
1210            }
1211        }
1212
1213        // Assert matches
1214        assert_eq!(our_ll_state, ref_ll_state, "LL state mismatch");
1215        assert_eq!(our_of_state, ref_of_state, "OF state mismatch");
1216        assert_eq!(our_ml_state, ref_ml_state, "ML state mismatch");
1217    }
1218
1219    /// Debug: print state_table construction
1220    #[test]
1221    fn test_state_table_construction() {
1222        println!("=== State Table Construction ===\n");
1223
1224        let ll_table = FseTable::from_predefined(
1225            &LITERAL_LENGTH_DEFAULT_DISTRIBUTION,
1226            LITERAL_LENGTH_ACCURACY_LOG,
1227        )
1228        .unwrap();
1229
1230        let encoder = TansEncoder::from_decode_table(&ll_table);
1231
1232        // Print first 20 entries of state_table
1233        println!("state_table (first 20 entries):");
1234        for i in 0..20 {
1235            let encoder_state = encoder.state_table[i];
1236            let decode_state = encoder_state as i32 - 64;
1237            let entry = ll_table.decode(decode_state as usize);
1238            println!(
1239                "  state_table[{:2}] = {} (decode_state={}, symbol={})",
1240                i, encoder_state, decode_state, entry.symbol
1241            );
1242        }
1243
1244        // Print symbol params
1245        println!("\nSymbol params (first 10 symbols):");
1246        for sym in 0..10 {
1247            if sym < encoder.symbol_params.len() {
1248                let params = &encoder.symbol_params[sym];
1249                println!(
1250                    "  Symbol {:2}: delta_nb_bits={:6}, delta_find_state={:3}",
1251                    sym, params.delta_nb_bits, params.delta_find_state
1252                );
1253            }
1254        }
1255    }
1256
1257    #[test]
1258    fn test_tans_encode_decode_roundtrip() {
1259        let table = FseTable::from_predefined(
1260            &LITERAL_LENGTH_DEFAULT_DISTRIBUTION,
1261            LITERAL_LENGTH_ACCURACY_LOG,
1262        )
1263        .unwrap();
1264
1265        let mut encoder = TansEncoder::from_decode_table(&table);
1266        let accuracy_log = encoder.accuracy_log;
1267
1268        // Encode a simple sequence: [0, 0, 0]
1269        let symbols = [0u8, 0, 0];
1270
1271        println!("symbol_params len: {}", encoder.symbol_params.len());
1272        println!(
1273            "num_bits_per_state len: {}",
1274            encoder.num_bits_per_state.len()
1275        );
1276        println!(
1277            "baseline_per_state len: {}",
1278            encoder.baseline_per_state.len()
1279        );
1280
1281        // Debug: print symbol params for symbol 0
1282        let params0 = &encoder.symbol_params[0];
1283        println!(
1284            "Symbol 0 params: delta_nb_bits={}, delta_find_state={}",
1285            params0.delta_nb_bits, params0.delta_find_state
1286        );
1287        println!(
1288            "  Expected nbBitsOut at state 64: (64 + {}) >> 16 = {}",
1289            params0.delta_nb_bits,
1290            (64u64 + params0.delta_nb_bits as u64) >> 16
1291        );
1292
1293        // Print decode table for first few states
1294        println!("Decode table:");
1295        for s in 0..4 {
1296            let entry = table.decode(s);
1297            println!(
1298                "  state {}: symbol={}, num_bits={}, baseline={}",
1299                s, entry.symbol, entry.num_bits, entry.baseline
1300            );
1301        }
1302
1303        // Initialize with last symbol (no bits output)
1304        encoder.init_state(symbols[2]);
1305        let init_state = encoder.state;
1306        println!(
1307            "After init_state(0): encoder_state={}, decode_state={}",
1308            init_state,
1309            init_state.saturating_sub(64)
1310        );
1311
1312        // Collect bits from encoding remaining symbols in reverse
1313        let mut all_bits: Vec<(u32, u8)> = Vec::new();
1314        for &sym in symbols[..2].iter().rev() {
1315            let old_state = encoder.state;
1316            let old_decode = old_state.saturating_sub(64);
1317            println!(
1318                "Before encode sym={}: encoder_state={}, decode_state={}",
1319                sym, old_state, old_decode
1320            );
1321            let (bits, nb) = encoder.encode_symbol(sym);
1322            let new_decode = encoder.state.saturating_sub(64);
1323            println!(
1324                "After encode: bits={}, nb_bits={}, new_decode_state={}",
1325                bits, nb, new_decode
1326            );
1327            all_bits.push((bits, nb));
1328        }
1329
1330        // Get final state for decoder init
1331        let final_state = encoder.get_state();
1332
1333        // Build bitstream: bits in forward order, then final state
1334        // Reader reads backwards, so forward write order gives correct read order
1335        let mut writer = FseBitWriter::new();
1336        for (bits, nb) in all_bits.iter() {
1337            writer.write_bits(*bits, *nb);
1338        }
1339        writer.write_bits(final_state, accuracy_log);
1340        let bitstream = writer.finish();
1341
1342        println!("Encoded sequence {:?}", symbols);
1343        println!("Init state: {}, Final state: {}", init_state, final_state);
1344        println!("Bits: {:?}", all_bits);
1345        println!("Bitstream ({} bytes): {:?}", bitstream.len(), bitstream);
1346
1347        // Decode
1348        let mut decoder = FseDecoder::new(&table);
1349        let mut bits_reader = BitReader::new(&bitstream);
1350        bits_reader.init_from_end().unwrap();
1351        println!(
1352            "Bits remaining after init_from_end: {}",
1353            bits_reader.bits_remaining()
1354        );
1355
1356        // Read initial state
1357        decoder.init_state(&mut bits_reader).unwrap();
1358        println!("Decoder initial state: {}", decoder.state());
1359        println!(
1360            "Bits remaining after init_state: {}",
1361            bits_reader.bits_remaining()
1362        );
1363
1364        // Decode symbols
1365        // Note: For N symbols, we encode N-1 (the first is just initialized)
1366        // So we decode N-1 symbols that read bits, then peek the last symbol
1367        let mut decoded = Vec::new();
1368
1369        // Decode first N-1 symbols (these read bits)
1370        for i in 0..2 {
1371            let entry = table.decode(decoder.state());
1372            println!(
1373                "Before decode[{}]: state={}, needs {} bits, bits_remaining={}",
1374                i,
1375                decoder.state(),
1376                entry.num_bits,
1377                bits_reader.bits_remaining()
1378            );
1379
1380            let sym = decoder.decode_symbol(&mut bits_reader).unwrap();
1381            decoded.push(sym);
1382            println!("Decoded: {}, new state: {}", sym, decoder.state());
1383        }
1384
1385        // Last symbol: just peek, don't read bits (this was the initialized one)
1386        let last_sym = decoder.peek_symbol();
1387        decoded.push(last_sym);
1388        println!("Last symbol (peek): {}", last_sym);
1389
1390        println!("Decoded sequence: {:?}", decoded);
1391        assert_eq!(
1392            decoded,
1393            symbols.to_vec(),
1394            "Decoded sequence doesn't match original"
1395        );
1396    }
1397
1398    #[test]
1399    fn test_tans_mixed_symbols_roundtrip() {
1400        let table = FseTable::from_predefined(
1401            &LITERAL_LENGTH_DEFAULT_DISTRIBUTION,
1402            LITERAL_LENGTH_ACCURACY_LOG,
1403        )
1404        .unwrap();
1405
1406        // Debug: print decode table for first 40 states
1407        println!("Decode table (first 40 states):");
1408        for s in 0..40 {
1409            let entry = table.decode(s);
1410            println!(
1411                "  state {:2}: symbol={:2}, num_bits={}, baseline={:2}",
1412                s, entry.symbol, entry.num_bits, entry.baseline
1413            );
1414        }
1415
1416        let mut encoder = TansEncoder::from_decode_table(&table);
1417        let accuracy_log = encoder.accuracy_log;
1418
1419        // Encode a sequence with different symbols: [0, 1, 2, 0, 1]
1420        let symbols = [0u8, 1, 2, 0, 1];
1421
1422        println!("\nEncoding symbols: {:?}", symbols);
1423
1424        // Initialize with last symbol
1425        encoder.init_state(symbols[4]);
1426        println!("After init_state({}): state={}", symbols[4], encoder.state);
1427
1428        // Encode remaining in reverse
1429        let mut all_bits: Vec<(u32, u8)> = Vec::new();
1430        for &sym in symbols[..4].iter().rev() {
1431            let (bits, nb) = encoder.encode_symbol(sym);
1432            println!(
1433                "Encode sym={}: bits={}, nb_bits={}, new_state={}",
1434                sym, bits, nb, encoder.state
1435            );
1436            all_bits.push((bits, nb));
1437        }
1438
1439        let final_state = encoder.get_state();
1440        println!("Final state: {}", final_state);
1441
1442        // Build bitstream
1443        // Bits are read backwards, so write in forward order: B4, B3, B2, B1, then D_0
1444        // This way reader gets: D_0, B1, B2, B3, B4 (correct order for decoding)
1445        let mut writer = FseBitWriter::new();
1446        for (bits, nb) in all_bits.iter() {
1447            // Forward order: B4, B3, B2, B1
1448            writer.write_bits(*bits, *nb);
1449        }
1450        writer.write_bits(final_state, accuracy_log);
1451        let bitstream = writer.finish();
1452
1453        println!("Bitstream ({} bytes): {:?}", bitstream.len(), bitstream);
1454
1455        // Decode
1456        let mut decoder = FseDecoder::new(&table);
1457        let mut bits_reader = BitReader::new(&bitstream);
1458        bits_reader.init_from_end().unwrap();
1459
1460        decoder.init_state(&mut bits_reader).unwrap();
1461        println!("Decoder initial state: {}", decoder.state());
1462
1463        let mut decoded = Vec::new();
1464        for _ in 0..4 {
1465            let sym = decoder.decode_symbol(&mut bits_reader).unwrap();
1466            decoded.push(sym);
1467            println!("Decoded: {}, new state: {}", sym, decoder.state());
1468        }
1469        let last_sym = decoder.peek_symbol();
1470        decoded.push(last_sym);
1471
1472        println!("Decoded sequence: {:?}", decoded);
1473        assert_eq!(
1474            decoded,
1475            symbols.to_vec(),
1476            "Decoded sequence doesn't match original"
1477        );
1478    }
1479
1480    #[test]
1481    fn test_ml_codes_38_and_43() {
1482        println!("\n=== ML Codes 38 and 43 State Mapping ===");
1483
1484        // Build ML table
1485        let ml_table = FseTable::from_predefined(
1486            &MATCH_LENGTH_DEFAULT_DISTRIBUTION,
1487            MATCH_LENGTH_ACCURACY_LOG,
1488        )
1489        .unwrap();
1490
1491        let mut encoder = TansEncoder::from_decode_table(&ml_table);
1492
1493        // Test code 38
1494        encoder.init_state(38);
1495        let state_38 = encoder.get_state();
1496        let decode_38 = ml_table.decode(state_38 as usize);
1497        println!(
1498            "ML code 38 -> state {} -> decodes to symbol {}",
1499            state_38, decode_38.symbol
1500        );
1501
1502        // Test code 43
1503        encoder.init_state(43);
1504        let state_43 = encoder.get_state();
1505        let decode_43 = ml_table.decode(state_43 as usize);
1506        println!(
1507            "ML code 43 -> state {} -> decodes to symbol {}",
1508            state_43, decode_43.symbol
1509        );
1510
1511        // Verify they decode correctly
1512        assert_eq!(
1513            decode_38.symbol, 38,
1514            "State {} should decode to symbol 38",
1515            state_38
1516        );
1517        assert_eq!(
1518            decode_43.symbol, 43,
1519            "State {} should decode to symbol 43",
1520            state_43
1521        );
1522    }
1523
1524    #[test]
1525    fn test_ll_code_23() {
1526        println!("\n=== LL Code 23 State Mapping ===");
1527
1528        // Build LL table
1529        let ll_table = FseTable::from_predefined(
1530            &LITERAL_LENGTH_DEFAULT_DISTRIBUTION,
1531            LITERAL_LENGTH_ACCURACY_LOG,
1532        )
1533        .unwrap();
1534
1535        let mut encoder = TansEncoder::from_decode_table(&ll_table);
1536
1537        // Test code 23
1538        encoder.init_state(23);
1539        let state_23 = encoder.get_state();
1540        let decode_23 = ll_table.decode(state_23 as usize);
1541        println!(
1542            "LL code 23 -> state {} -> decodes to symbol {}",
1543            state_23, decode_23.symbol
1544        );
1545
1546        // Verify it decodes correctly
1547        assert_eq!(
1548            decode_23.symbol, 23,
1549            "State {} should decode to symbol 23",
1550            state_23
1551        );
1552    }
1553}
1554
#[cfg(test)]
mod trace_tests {
    use super::*;
    use crate::fse::{
        LITERAL_LENGTH_ACCURACY_LOG, LITERAL_LENGTH_DEFAULT_DISTRIBUTION,
        MATCH_LENGTH_ACCURACY_LOG, MATCH_LENGTH_DEFAULT_DISTRIBUTION, OFFSET_ACCURACY_LOG,
        OFFSET_DEFAULT_DISTRIBUTION,
    };

    /// Hand-computes `(bits, nb_bits)` for an encode step starting at `state`,
    /// using `nb_bits = (state + delta_nb_bits) >> 16` and the low `nb_bits`
    /// bits of `state` as the emitted value.
    fn predict_output(state: u32, delta_nb_bits: u32) -> (u32, u8) {
        let nb = ((state as u64 + delta_nb_bits as u64) >> 16) as u8;
        (state & ((1u32 << nb) - 1), nb)
    }

    /// Traces one encode step on each of the LL, OF and ML encoders: the states
    /// are seeded with codes (0, 2, 43), then codes (4, 2, 45) are encoded and
    /// the hand-computed output is printed next to what `encode_symbol` returns.
    #[test]
    fn test_trace_encode_sequence() {
        println!("\n=== Trace FSE Encode Sequence ===\n");

        // Predefined decode tables.
        let ll_table = FseTable::from_predefined(
            &LITERAL_LENGTH_DEFAULT_DISTRIBUTION,
            LITERAL_LENGTH_ACCURACY_LOG,
        )
        .unwrap();
        let of_table =
            FseTable::from_predefined(&OFFSET_DEFAULT_DISTRIBUTION, OFFSET_ACCURACY_LOG).unwrap();
        let ml_table = FseTable::from_predefined(
            &MATCH_LENGTH_DEFAULT_DISTRIBUTION,
            MATCH_LENGTH_ACCURACY_LOG,
        )
        .unwrap();

        // One encoder per table.
        let mut ll_enc = TansEncoder::from_decode_table(&ll_table);
        let mut of_enc = TansEncoder::from_decode_table(&of_table);
        let mut ml_enc = TansEncoder::from_decode_table(&ml_table);

        // Parameters of the two LL symbols involved in this trace.
        let ll_sym0 = ll_enc.symbol_params[0];
        let ll_sym4 = ll_enc.symbol_params[4];
        println!(
            "LL symbol 0 params: delta_nb_bits={}, delta_find_state={}",
            ll_sym0.delta_nb_bits, ll_sym0.delta_find_state
        );
        println!(
            "LL symbol 4 params: delta_nb_bits={}, delta_find_state={}",
            ll_sym4.delta_nb_bits, ll_sym4.delta_find_state
        );

        // Seed the states with the seq[1] codes.
        ll_enc.init_state(0);
        of_enc.init_state(2);
        ml_enc.init_state(43);

        // Decoder states are printed as encoder state minus table size
        // (64 for LL and ML, 32 for OF).
        let (ll_s0, of_s0, ml_s0) = (ll_enc.state, of_enc.state, ml_enc.state);
        println!("\nAfter init:");
        println!(
            "  LL: encoder_state={}, decoder_state={}",
            ll_s0,
            ll_s0 - 64
        );
        println!(
            "  OF: encoder_state={}, decoder_state={}",
            of_s0,
            of_s0 - 32
        );
        println!(
            "  ML: encoder_state={}, decoder_state={}",
            ml_s0,
            ml_s0 - 64
        );

        // Encode the seq[0] codes, predicting each output by hand first.
        println!("\nEncoding seq[0] codes (4, 2, 45):");

        // Literal length, code 4.
        let ll_delta = ll_enc.symbol_params[4].delta_nb_bits;
        let (ll_bits, ll_nb) = predict_output(ll_s0, ll_delta);
        println!(
            "  LL: state={}, delta_nb_bits={}, nb_bits_out={}, bits={}",
            ll_s0, ll_delta, ll_nb, ll_bits
        );
        let (ll_out_bits, ll_out_nb) = ll_enc.encode_symbol(4);
        println!(
            "  LL encode_symbol output: bits={}, nb={}",
            ll_out_bits, ll_out_nb
        );

        // Offset, code 2.
        let of_delta = of_enc.symbol_params[2].delta_nb_bits;
        let (of_bits, of_nb) = predict_output(of_s0, of_delta);
        println!(
            "  OF: state={}, delta_nb_bits={}, nb_bits_out={}, bits={}",
            of_s0, of_delta, of_nb, of_bits
        );
        let (of_out_bits, of_out_nb) = of_enc.encode_symbol(2);
        println!(
            "  OF encode_symbol output: bits={}, nb={}",
            of_out_bits, of_out_nb
        );

        // Match length, code 45.
        let ml_delta = ml_enc.symbol_params[45].delta_nb_bits;
        let (ml_bits, ml_nb) = predict_output(ml_s0, ml_delta);
        println!(
            "  ML: state={}, delta_nb_bits={}, nb_bits_out={}, bits={}",
            ml_s0, ml_delta, ml_nb, ml_bits
        );
        let (ml_out_bits, ml_out_nb) = ml_enc.encode_symbol(45);
        println!(
            "  ML encode_symbol output: bits={}, nb={}",
            ml_out_bits, ml_out_nb
        );

        let (ll_s1, of_s1, ml_s1) = (ll_enc.state, of_enc.state, ml_enc.state);
        println!("\nAfter encode:");
        println!(
            "  LL: encoder_state={}, decoder_state={}",
            ll_s1,
            ll_s1 - 64
        );
        println!(
            "  OF: encoder_state={}, decoder_state={}",
            of_s1,
            of_s1 - 32
        );
        println!(
            "  ML: encoder_state={}, decoder_state={}",
            ml_s1,
            ml_s1 - 64
        );

        // Show which symbol each resulting decoder state maps to.
        println!("\nDecode table verification:");
        println!(
            "  LL[{}] = symbol {}",
            ll_s1 - 64,
            ll_table.decode((ll_s1 - 64) as usize).symbol
        );
        println!(
            "  OF[{}] = symbol {}",
            of_s1 - 32,
            of_table.decode((of_s1 - 32) as usize).symbol
        );
        println!(
            "  ML[{}] = symbol {}",
            ml_s1 - 64,
            ml_table.decode((ml_s1 - 64) as usize).symbol
        );
    }
}