// haagenti_zstd/fse/encoder.rs

//! FSE stream encoder.
//!
//! Implements the Finite State Entropy (FSE) encoder used by Zstandard
//! compression. FSE is a table-based variant of ANS (Asymmetric Numeral
//! Systems) that achieves near-optimal compression with very fast decoding.
//!
//! ## Optimizations
//!
//! - Encoding tables are derived directly from the FSE decoding table.
//! - Each symbol's states are stored contiguously in one flat table, with
//!   per-symbol start indices and counts precomputed so a symbol's states are
//!   located in O(1).
//! - Encoding entries are packed into 8-byte records.
//! - State transitions update a single integer; no table data is copied.
//!
//! ## References
//!
//! - [RFC 8878 Section 4.1](https://datatracker.ietf.org/doc/html/rfc8878#section-4.1)
//! - [Asymmetric Numeral Systems](https://arxiv.org/abs/0902.0271)

20use super::table::FseTable;
21
/// One FSE encoding entry: a single decoding-table state seen from the encoder.
///
/// The fields occupy 6 bytes (`u8`, one padding byte, `i16`, `u16`);
/// `#[repr(C, align(8))]` rounds the struct up to an 8-byte, 8-byte-aligned record.
#[derive(Debug, Clone, Copy, Default)]
#[repr(C, align(8))]
pub struct FseEncodeEntry {
    /// Number of state bits a decoder reads after emitting this state's symbol.
    pub num_bits: u8,
    /// The decoding-table state this entry stands for; the encoder moves to
    /// this state when it selects the entry. Stored as `i16`, so tables are
    /// limited to 2^15 states.
    pub delta_find_state: i16,
    /// Packed `(num_bits << 8) | (baseline & 0xFF)` as filled in when the
    /// encoder is built from a decoding table. Only the low 8 bits of the
    /// baseline survive this packing.
    // The field is written but never read inside the crate; silence the lint.
    #[allow(dead_code)]
    pub delta_nb_bits: u16,
}

36/// Optimized FSE encoder using flat table layout.
37///
38/// The encoder uses a flat array indexed by (symbol, occurrence) for
39/// cache-efficient encoding. Symbol frequency counts are updated inline.
40#[derive(Debug)]
41pub struct FseEncoder {
42    /// Flat encoding table: indexed by (symbol * max_states + state_index)
43    /// Each symbol has up to max_states encoding entries.
44    encode_table: Vec<FseEncodeEntry>,
45    /// Number of states per symbol (table_size / num_symbols average)
46    #[allow(dead_code)]
47    states_per_symbol: usize,
48    /// Symbol count array: number of states for each symbol
49    symbol_counts: [u16; 256],
50    /// Symbol start indices in the flat table
51    symbol_starts: [u32; 256],
52    /// Current symbol occurrence counters
53    symbol_next: [u16; 256],
54    /// Current encoder state.
55    state: usize,
56    /// Accuracy log.
57    accuracy_log: u8,
58    /// Table size
59    #[allow(dead_code)]
60    table_size: usize,
61}
62
63impl FseEncoder {
64    /// Build an optimized FSE encoder from a decoding table.
65    ///
66    /// Uses a flat table layout for cache efficiency.
67    pub fn from_decode_table(decode_table: &FseTable) -> Self {
68        let accuracy_log = decode_table.accuracy_log();
69        let table_size = decode_table.size();
70
71        // Count states per symbol
72        let mut symbol_counts = [0u16; 256];
73        for state in 0..table_size {
74            let entry = decode_table.decode(state);
75            symbol_counts[entry.symbol as usize] += 1;
76        }
77
78        // Calculate symbol start indices
79        let mut symbol_starts = [0u32; 256];
80        let mut offset = 0u32;
81        for i in 0..256 {
82            symbol_starts[i] = offset;
83            offset += symbol_counts[i] as u32;
84        }
85
86        // Build flat encoding table
87        let total_entries = table_size;
88        let mut encode_table = vec![FseEncodeEntry::default(); total_entries];
89        let mut symbol_next_temp = [0u16; 256];
90
91        for state in 0..table_size {
92            let decode_entry = decode_table.decode(state);
93            let symbol = decode_entry.symbol as usize;
94
95            let idx = symbol_starts[symbol] as usize + symbol_next_temp[symbol] as usize;
96            symbol_next_temp[symbol] += 1;
97
98            if idx < encode_table.len() {
99                encode_table[idx] = FseEncodeEntry {
100                    num_bits: decode_entry.num_bits,
101                    delta_find_state: state as i16,
102                    delta_nb_bits: (decode_entry.num_bits as u16) << 8
103                        | (decode_entry.baseline & 0xFF),
104                };
105            }
106        }
107
108        Self {
109            encode_table,
110            states_per_symbol: table_size / 256,
111            symbol_counts,
112            symbol_starts,
113            symbol_next: [0u16; 256],
114            state: 0,
115            accuracy_log,
116            table_size,
117        }
118    }
119
120    /// Initialize the encoder with a symbol (first symbol sets initial state).
121    #[inline]
122    pub fn init_state(&mut self, symbol: u8) {
123        let sym_idx = symbol as usize;
124        if self.symbol_counts[sym_idx] > 0 {
125            let entry_idx = self.symbol_starts[sym_idx] as usize;
126            if entry_idx < self.encode_table.len() {
127                self.state = self.encode_table[entry_idx].delta_find_state as usize;
128            }
129        }
130        // Reset occurrence counters
131        self.symbol_next = [0u16; 256];
132    }
133
134    /// Get the current state for serialization.
135    #[inline]
136    pub fn get_state(&self) -> usize {
137        self.state
138    }
139
140    /// Get the accuracy log.
141    #[inline]
142    pub fn accuracy_log(&self) -> u8 {
143        self.accuracy_log
144    }
145
146    /// Encode a symbol, returning the bits to output.
147    ///
148    /// Returns (bits_value, num_bits) where bits_value contains num_bits to output.
149    #[inline]
150    pub fn encode_symbol(&mut self, symbol: u8) -> (u32, u8) {
151        let sym_idx = symbol as usize;
152        let count = self.symbol_counts[sym_idx];
153
154        if count == 0 {
155            return (0, 0);
156        }
157
158        // Get entry for this symbol's current occurrence
159        let occurrence = self.symbol_next[sym_idx] % count;
160        let entry_idx = self.symbol_starts[sym_idx] as usize + occurrence as usize;
161
162        self.symbol_next[sym_idx] = self.symbol_next[sym_idx].wrapping_add(1);
163
164        if entry_idx >= self.encode_table.len() {
165            return (0, 0);
166        }
167
168        let entry = &self.encode_table[entry_idx];
169        let num_bits = entry.num_bits;
170        let mask = (1u32 << num_bits) - 1;
171        let bits = (self.state as u32) & mask;
172
173        // Update state
174        self.state = entry.delta_find_state as usize;
175
176        (bits, num_bits)
177    }
178
179    /// Reset for encoding a new stream.
180    #[inline]
181    pub fn reset(&mut self) {
182        self.state = 0;
183        self.symbol_next = [0u16; 256];
184    }
185}
186
/// Little-endian (LSB-first) bit writer for FSE bitstreams.
///
/// Bits are packed into a 64-bit accumulator and moved to the output buffer
/// in whole bytes. The first bit written becomes bit 0 of the first byte.
#[derive(Debug)]
pub struct FseBitWriter {
    /// Completed output bytes.
    buffer: Vec<u8>,
    /// Pending bits, LSB-first; only the low `bits_in_accum` bits are valid.
    accum: u64,
    /// Number of pending bits in `accum` (never more than 64).
    bits_in_accum: u32,
}

impl FseBitWriter {
    /// Create an empty writer with room for 256 output bytes.
    #[inline]
    pub fn new() -> Self {
        Self::with_capacity(256)
    }

    /// Create an empty writer whose buffer holds `capacity` bytes before reallocating.
    #[inline]
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            buffer: Vec::with_capacity(capacity),
            accum: 0,
            bits_in_accum: 0,
        }
    }

    /// Append the low `num_bits` bits of `value` to the stream.
    ///
    /// Bits of `value` above `num_bits` are ignored. `num_bits` must be at
    /// most 32. Writing 0 bits is a no-op.
    #[inline]
    pub fn write_bits(&mut self, value: u32, num_bits: u8) {
        if num_bits == 0 {
            return;
        }
        debug_assert!(num_bits <= 32, "write_bits supports at most 32 bits per call");
        let width = u32::from(num_bits);

        // Make room *before* adding: the accumulator holds at most 64 bits, and
        // anything shifted past bit 63 would be silently lost.
        if self.bits_in_accum + width > 64 {
            self.flush_bytes();
        }

        // Mask so stray high bits in `value` cannot corrupt later fields.
        let mask = if width >= 32 {
            u32::MAX
        } else {
            (1u32 << width) - 1
        };
        self.accum |= u64::from(value & mask) << self.bits_in_accum;
        self.bits_in_accum += width;
    }

    /// Move every complete byte from the accumulator to `buffer`, leaving
    /// fewer than 8 pending bits.
    #[inline]
    fn flush_bytes(&mut self) {
        // 4 bytes at a time while possible, then single bytes.
        while self.bits_in_accum >= 32 {
            let bytes = (self.accum as u32).to_le_bytes();
            self.buffer.extend_from_slice(&bytes);
            self.accum >>= 32;
            self.bits_in_accum -= 32;
        }
        while self.bits_in_accum >= 8 {
            self.buffer.push((self.accum & 0xFF) as u8);
            self.accum >>= 8;
            self.bits_in_accum -= 8;
        }
    }

    /// Finish the bitstream: append a single `1` sentinel bit, then pad the
    /// final byte with zero bits.
    pub fn finish(mut self) -> Vec<u8> {
        self.write_bits(1, 1);
        self.into_bytes()
    }

    /// Return the bytes written so far, zero-padding a trailing partial byte.
    ///
    /// Unlike [`finish`](Self::finish), no sentinel bit is added.
    pub fn into_bytes(mut self) -> Vec<u8> {
        self.flush_bytes();

        // Fewer than 8 bits remain, and the accumulator above them is zero.
        if self.bits_in_accum > 0 {
            self.buffer.push(self.accum as u8);
        }
        self.buffer
    }

    /// Number of bytes the stream occupies so far, counting a partial byte as a full one.
    #[inline]
    pub fn len(&self) -> usize {
        self.buffer.len() + (self.bits_in_accum as usize).div_ceil(8)
    }

    /// Check whether nothing has been written yet.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.buffer.is_empty() && self.bits_in_accum == 0
    }
}

impl Default for FseBitWriter {
    fn default() -> Self {
        Self::new()
    }
}

// =============================================================================
// Interleaved FSE Encoder for Sequences
// =============================================================================

307/// Interleaved FSE encoder for Zstd sequences.
308///
309/// Zstd sequences use three interleaved FSE streams:
310/// - Literal Length (LL)
311/// - Offset (OF)
312/// - Match Length (ML)
313///
314/// Each stream maintains its own state, and bits are interleaved in a
315/// specific order for optimal decoding performance.
316#[derive(Debug)]
317pub struct InterleavedFseEncoder {
318    /// Literal length encoder
319    ll_encoder: FseEncoder,
320    /// Offset encoder
321    of_encoder: FseEncoder,
322    /// Match length encoder
323    ml_encoder: FseEncoder,
324}
325
326impl InterleavedFseEncoder {
327    /// Create a new interleaved encoder from the three FSE tables.
328    pub fn new(ll_table: &FseTable, of_table: &FseTable, ml_table: &FseTable) -> Self {
329        Self {
330            ll_encoder: FseEncoder::from_decode_table(ll_table),
331            of_encoder: FseEncoder::from_decode_table(of_table),
332            ml_encoder: FseEncoder::from_decode_table(ml_table),
333        }
334    }
335
336    /// Initialize all three encoders with their first symbols.
337    #[inline]
338    pub fn init_states(&mut self, ll: u8, of: u8, ml: u8) {
339        self.ll_encoder.init_state(ll);
340        self.of_encoder.init_state(of);
341        self.ml_encoder.init_state(ml);
342    }
343
344    /// Encode one sequence triplet (LL, OF, ML).
345    ///
346    /// Returns the bits and counts for each stream in the correct order.
347    #[inline]
348    pub fn encode_sequence(&mut self, ll: u8, of: u8, ml: u8) -> [(u32, u8); 3] {
349        // Encoding order for Zstd: OF, ML, LL
350        let of_bits = self.of_encoder.encode_symbol(of);
351        let ml_bits = self.ml_encoder.encode_symbol(ml);
352        let ll_bits = self.ll_encoder.encode_symbol(ll);
353
354        [of_bits, ml_bits, ll_bits]
355    }
356
357    /// Get the final states for all three encoders.
358    #[inline]
359    pub fn get_states(&self) -> (usize, usize, usize) {
360        (
361            self.ll_encoder.get_state(),
362            self.of_encoder.get_state(),
363            self.ml_encoder.get_state(),
364        )
365    }
366
367    /// Get accuracy logs for all three encoders.
368    #[inline]
369    pub fn accuracy_logs(&self) -> (u8, u8, u8) {
370        (
371            self.ll_encoder.accuracy_log(),
372            self.of_encoder.accuracy_log(),
373            self.ml_encoder.accuracy_log(),
374        )
375    }
376
377    /// Reset all encoders for a new sequence section.
378    #[inline]
379    pub fn reset(&mut self) {
380        self.ll_encoder.reset();
381        self.of_encoder.reset();
382        self.ml_encoder.reset();
383    }
384}
385
// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use crate::fse::{FseTable, LITERAL_LENGTH_ACCURACY_LOG, LITERAL_LENGTH_DEFAULT_DISTRIBUTION};

    /// Predefined literal-length table shared by the single-stream tests.
    fn literal_length_table() -> FseTable {
        FseTable::from_predefined(
            &LITERAL_LENGTH_DEFAULT_DISTRIBUTION,
            LITERAL_LENGTH_ACCURACY_LOG,
        )
        .unwrap()
    }

    #[test]
    fn test_fse_encoder_creation() {
        let table = literal_length_table();
        let encoder = FseEncoder::from_decode_table(&table);
        assert_eq!(encoder.accuracy_log(), LITERAL_LENGTH_ACCURACY_LOG);
    }

    #[test]
    fn test_fse_encoder_init_state() {
        let table = literal_length_table();
        let mut encoder = FseEncoder::from_decode_table(&table);
        encoder.init_state(0);

        // The initial state must index into the decoding table.
        assert!(encoder.get_state() < table.size());
    }

    #[test]
    fn test_fse_encoder_encode_symbol() {
        let table = literal_length_table();
        let mut encoder = FseEncoder::from_decode_table(&table);
        encoder.init_state(0);

        let outputs = std::iter::repeat_with(|| encoder.encode_symbol(0)).take(10);
        for (bits, num_bits) in outputs {
            // Width never exceeds the table's accuracy log...
            assert!(num_bits <= LITERAL_LENGTH_ACCURACY_LOG);
            // ...and the emitted value fits in that width.
            assert!(num_bits == 0 || bits < (1 << num_bits));
        }
    }

    #[test]
    fn test_fse_bit_writer_simple() {
        let mut writer = FseBitWriter::new();
        writer.write_bits(0b101, 3);

        // Three data bits plus the sentinel produce at least one byte.
        assert!(!writer.finish().is_empty());
    }

    #[test]
    fn test_fse_bit_writer_multi_byte() {
        let mut writer = FseBitWriter::new();
        for byte in [0xAB, 0xCD] {
            writer.write_bits(byte, 8);
        }
        assert_eq!(writer.into_bytes(), [0xAB, 0xCD]);
    }

    #[test]
    fn test_fse_bit_writer_capacity() {
        assert!(FseBitWriter::with_capacity(1024).is_empty());
    }

    #[test]
    fn test_fse_bit_writer_len() {
        let mut writer = FseBitWriter::new();
        for expected_len in 1..=2 {
            writer.write_bits(0xFF, 8);
            assert_eq!(writer.len(), expected_len);
        }
    }

    #[test]
    fn test_fse_bit_writer_large() {
        let mut writer = FseBitWriter::new();
        (0..1000u32).for_each(|i| writer.write_bits(i % 256, 8));
        assert_eq!(writer.into_bytes().len(), 1000);
    }

    #[test]
    fn test_interleaved_encoder() {
        use crate::fse::{
            MATCH_LENGTH_ACCURACY_LOG, MATCH_LENGTH_DEFAULT_DISTRIBUTION, OFFSET_ACCURACY_LOG,
            OFFSET_DEFAULT_DISTRIBUTION,
        };

        let ll_table = literal_length_table();
        let ml_table = FseTable::from_predefined(
            &MATCH_LENGTH_DEFAULT_DISTRIBUTION,
            MATCH_LENGTH_ACCURACY_LOG,
        )
        .unwrap();
        let of_table =
            FseTable::from_predefined(&OFFSET_DEFAULT_DISTRIBUTION, OFFSET_ACCURACY_LOG).unwrap();

        let mut encoder = InterleavedFseEncoder::new(&ll_table, &of_table, &ml_table);
        encoder.init_states(0, 0, 0);

        // Output order is OF, ML, LL; only the widths are checked here.
        let [(_, of_width), (_, ml_width), (_, ll_width)] = encoder.encode_sequence(0, 0, 0);
        assert!(of_width <= OFFSET_ACCURACY_LOG);
        assert!(ml_width <= MATCH_LENGTH_ACCURACY_LOG);
        assert!(ll_width <= LITERAL_LENGTH_ACCURACY_LOG);
    }
}