Skip to main content

haagenti_zstd/huffman/
encoder.rs

1//! Huffman encoding for Zstd literals.
2//!
3//! This module implements high-performance Huffman encoding for Zstd compression.
4//!
5//! ## Optimizations
6//!
7//! - SIMD-accelerated frequency counting (histogram)
8//! - 64-bit accumulator for efficient bit packing
9//! - Cache-friendly code table layout
10//! - Vectorized encoding for batch processing
11//!
12//! ## Weight System
13//!
14//! In Zstd Huffman encoding:
15//! - Weight `w > 0` means `code_length = max_bits + 1 - w`
16//! - Weight `0` means symbol is not present
17//! - Higher weight = shorter code = more frequent symbol
18//! - Maximum weight is 11 (minimum code length = 1 bit)
19//!
20//! ## References
21//!
22//! - [RFC 8878 Section 4.2](https://datatracker.ietf.org/doc/html/rfc8878#section-4.2)
23
24use crate::fse::{FseBitWriter, FseTable};
25
/// Maximum number of symbols for Huffman encoding (256 for bytes).
const MAX_SYMBOLS: usize = 256;

/// Maximum Huffman weight (caps code length; with `max_bits == 11` a
/// weight-11 symbol gets the minimum 1-bit code — see module docs).
const MAX_WEIGHT: u8 = 11;

/// Minimum data size to benefit from Huffman encoding; `build` rejects
/// smaller inputs, where table overhead would dominate any savings.
const MIN_HUFFMAN_SIZE: usize = 32;
34
/// Huffman encoding table entry - packed for cache efficiency.
///
/// Exactly 4 bytes (`u16` code + `u8` length + 1 pad byte, `align(4)`),
/// so 16 entries share a 64-byte cache line and the full 256-entry
/// table occupies 1 KiB.
#[derive(Debug, Clone, Copy, Default)]
#[repr(C, align(4))]
pub struct HuffmanCode {
    /// The Huffman code bits (stored in LSB).
    pub code: u16,
    /// Number of bits in the code (0 means the symbol is absent).
    pub num_bits: u8,
    /// Padding for alignment.
    _pad: u8,
}

impl HuffmanCode {
    /// Build an entry from a code value and its bit length.
    #[inline]
    const fn new(code: u16, num_bits: u8) -> Self {
        Self {
            code,
            num_bits,
            _pad: 0,
        }
    }
}
57
/// Optimized Huffman encoder for literal compression.
///
/// Constructed either from raw literal data (`build`) or from
/// pre-computed weights (`from_weights`).
#[derive(Debug)]
pub struct HuffmanEncoder {
    /// Encoding table: symbol -> code (256 entries, cache-aligned)
    codes: Box<[HuffmanCode; MAX_SYMBOLS]>,
    /// Symbol weights for serialization (256 entries; 0 = symbol absent)
    weights: Vec<u8>,
    /// Maximum code length in bits (equals the largest assigned weight)
    max_bits: u8,
    /// Number of symbols with non-zero weight
    num_symbols: usize,
    /// Highest symbol index with non-zero weight (for weight table sizing)
    last_symbol: usize,
}
72
73impl HuffmanEncoder {
74    /// Build a Huffman encoder from literal data.
75    ///
76    /// Uses SIMD-accelerated histogram when available.
77    /// Returns None if data cannot be efficiently Huffman-compressed.
78    pub fn build(data: &[u8]) -> Option<Self> {
79        if data.len() < MIN_HUFFMAN_SIZE {
80            return None;
81        }
82
83        // Count symbol frequencies using optimized histogram
84        let freq = Self::count_frequencies(data);
85
86        // Count unique symbols and find last symbol with non-zero frequency
87        let unique_count = freq.iter().filter(|&&f| f > 0).count();
88        if unique_count < 2 {
89            return None; // Use RLE instead
90        }
91
92        let last_symbol = freq
93            .iter()
94            .enumerate()
95            .filter(|&(_, &f)| f > 0)
96            .map(|(i, _)| i)
97            .max()
98            .unwrap_or(0);
99
100        // Convert frequencies to weights
101        let (weights, max_bits) = Self::frequencies_to_weights(&freq)?;
102
103        // Generate canonical codes
104        let codes = Self::generate_canonical_codes(&weights, max_bits);
105
106        Some(Self {
107            codes: Box::new(codes),
108            weights,
109            max_bits,
110            num_symbols: unique_count,
111            last_symbol,
112        })
113    }
114
115    /// Build a Huffman encoder from pre-defined weights.
116    ///
117    /// This allows using custom Huffman tables instead of building from data.
118    /// Useful when you have pre-trained weights from dictionary compression
119    /// or want to reuse weights across multiple blocks.
120    ///
121    /// # Parameters
122    ///
123    /// - `weights`: Array of 256 weights (one per byte value). Weight 0 means
124    ///   symbol is not present. Weight w > 0 means code_length = max_bits + 1 - w.
125    ///
126    /// # Returns
127    ///
128    /// Returns `Some(encoder)` if the weights are valid, `None` otherwise.
129    ///
130    /// # Example
131    ///
132    /// ```rust
133    /// use haagenti_zstd::huffman::HuffmanEncoder;
134    ///
135    /// // Define weights for symbols 'a' (97), 'b' (98), 'c' (99)
136    /// let mut weights = vec![0u8; 256];
137    /// weights[97] = 3;  // 'a' - highest weight (shortest code)
138    /// weights[98] = 2;  // 'b' - medium weight
139    /// weights[99] = 1;  // 'c' - lowest weight (longest code)
140    ///
141    /// let encoder = HuffmanEncoder::from_weights(&weights).unwrap();
142    /// ```
143    pub fn from_weights(weights: &[u8]) -> Option<Self> {
144        if weights.len() != MAX_SYMBOLS {
145            return None;
146        }
147
148        // Count unique symbols and find last symbol with non-zero weight
149        let unique_count = weights.iter().filter(|&&w| w > 0).count();
150        if unique_count < 2 {
151            return None; // Need at least 2 symbols
152        }
153
154        let last_symbol = weights
155            .iter()
156            .enumerate()
157            .filter(|&(_, &w)| w > 0)
158            .map(|(i, _)| i)
159            .max()
160            .unwrap_or(0);
161
162        // Find max weight to determine max_bits
163        let max_weight = *weights.iter().max().unwrap_or(&0);
164        if max_weight == 0 || max_weight > MAX_WEIGHT {
165            return None;
166        }
167
168        // Calculate max_bits from max_weight
169        // In Zstd: code_length = max_bits + 1 - weight
170        // For the highest weight symbol, code_length should be 1, so:
171        // max_bits = max_weight
172        let max_bits = max_weight;
173
174        // Generate canonical codes from weights
175        let codes = Self::generate_canonical_codes(weights, max_bits);
176
177        Some(Self {
178            codes: Box::new(codes),
179            weights: weights.to_vec(),
180            max_bits,
181            num_symbols: unique_count,
182            last_symbol,
183        })
184    }
185
186    /// Count byte frequencies using optimized histogram.
187    ///
188    /// Uses SIMD acceleration when available via haagenti_simd.
189    #[inline]
190    fn count_frequencies(data: &[u8]) -> [u32; MAX_SYMBOLS] {
191        // Use SIMD-accelerated histogram when feature is enabled
192        #[cfg(feature = "simd")]
193        {
194            haagenti_simd::byte_histogram(data)
195        }
196
197        // Optimized scalar fallback using 4-way interleaved counting
198        // This reduces cache line conflicts from histogram updates
199        #[cfg(not(feature = "simd"))]
200        {
201            let mut freq0 = [0u32; MAX_SYMBOLS];
202            let mut freq1 = [0u32; MAX_SYMBOLS];
203            let mut freq2 = [0u32; MAX_SYMBOLS];
204            let mut freq3 = [0u32; MAX_SYMBOLS];
205
206            // Process 16 bytes at a time with 4 interleaved histograms
207            let chunks = data.chunks_exact(16);
208            let remainder = chunks.remainder();
209
210            for chunk in chunks {
211                // Interleave to reduce pipeline stalls from same-address increments
212                freq0[chunk[0] as usize] += 1;
213                freq1[chunk[1] as usize] += 1;
214                freq2[chunk[2] as usize] += 1;
215                freq3[chunk[3] as usize] += 1;
216                freq0[chunk[4] as usize] += 1;
217                freq1[chunk[5] as usize] += 1;
218                freq2[chunk[6] as usize] += 1;
219                freq3[chunk[7] as usize] += 1;
220                freq0[chunk[8] as usize] += 1;
221                freq1[chunk[9] as usize] += 1;
222                freq2[chunk[10] as usize] += 1;
223                freq3[chunk[11] as usize] += 1;
224                freq0[chunk[12] as usize] += 1;
225                freq1[chunk[13] as usize] += 1;
226                freq2[chunk[14] as usize] += 1;
227                freq3[chunk[15] as usize] += 1;
228            }
229
230            // Handle remainder
231            for &byte in remainder {
232                freq0[byte as usize] += 1;
233            }
234
235            // Merge the 4 histograms
236            for i in 0..MAX_SYMBOLS {
237                freq0[i] += freq1[i] + freq2[i] + freq3[i];
238            }
239
240            freq0
241        }
242    }
243
    /// Convert frequencies to Zstd Huffman weights.
    ///
    /// Aims for weights that satisfy the Kraft equality:
    /// sum(2^weight) = 2^(max_weight + 1)
    ///
    /// # Algorithm Complexity: O(n log n)
    ///
    /// 1. Sort symbols by frequency: O(n log n)
    /// 2. Calculate initial weights based on frequency ratios: O(n)
    /// 3. Adjust weights toward the Kraft target with sort-based greedy
    ///    passes: O(n log n)
    ///
    /// NOTE(review): the adjustment passes stop when no legal single-step
    /// change fits the remaining gap, so `kraft_sum` is not guaranteed to
    /// land exactly on the target in every edge case — confirm downstream
    /// tolerance for slightly under-full tables.
    fn frequencies_to_weights(freq: &[u32; MAX_SYMBOLS]) -> Option<(Vec<u8>, u8)> {
        // Collect (symbol, frequency) pairs for symbols that occur.
        let mut symbols: Vec<(usize, u32)> = freq
            .iter()
            .enumerate()
            .filter(|&(_, &f)| f > 0)
            .map(|(i, &f)| (i, f))
            .collect();

        // Fewer than two symbols cannot form a Huffman tree.
        if symbols.len() < 2 {
            return None;
        }

        let n = symbols.len();

        // Special case: exactly 2 symbols get weight 1 each (1-bit codes).
        if n == 2 {
            let mut weights = vec![0u8; MAX_SYMBOLS];
            weights[symbols[0].0] = 1;
            weights[symbols[1].0] = 1;
            return Some((weights, 1));
        }

        // Sort symbols by frequency (highest first) - O(n log n)
        symbols.sort_unstable_by(|a, b| b.1.cmp(&a.1));

        // Minimum tree depth for n symbols: ceil(log2(n)), computed via
        // leading_zeros on n - 1.
        let min_exp = if n <= 2 {
            0
        } else {
            64 - ((n - 1) as u64).leading_zeros()
        };
        // Heaviest weight = depth + 1, clamped to the format limit.
        let max_weight = ((min_exp + 1) as u8).clamp(1, MAX_WEIGHT);

        let mut weights = vec![0u8; MAX_SYMBOLS];
        // Kraft target: sum(2^weight) must equal 2^(max_weight + 1).
        let target = 1u64 << (max_weight + 1);

        // Phase 1: Assign initial weights based on frequency ratio - O(n)
        // Use log2(max_freq / freq) to estimate relative code lengths.
        let max_freq = symbols[0].1 as u64;

        for (idx, &(sym, freq)) in symbols.iter().enumerate() {
            if idx == 0 {
                // Most frequent symbol gets max_weight (shortest code)
                weights[sym] = max_weight;
            } else {
                // Ceiling division: ratio of max frequency to this one.
                // Higher ratio = lower frequency = lower weight = longer code.
                let ratio = (max_freq + freq as u64 - 1) / freq.max(1) as u64;
                // floor(log2(ratio)) via leading_zeros.
                let log_ratio = if ratio <= 1 {
                    0
                } else {
                    (64 - ratio.leading_zeros()).saturating_sub(1) as u8
                };
                // Clamp to valid range [1, max_weight]
                let w = max_weight.saturating_sub(log_ratio).max(1);
                weights[sym] = w;
            }
        }

        // Calculate current Kraft sum - O(n)
        let mut kraft_sum: u64 = symbols.iter().map(|(sym, _)| 1u64 << weights[*sym]).sum();

        // Phase 2: Adjust weights toward the Kraft target - O(n log n)
        // Greedy over symbols ordered by current weight.

        if kraft_sum < target {
            // Under capacity: increase weights (shorter codes), starting
            // with the lowest-weight symbols (most headroom).
            let mut by_weight: Vec<(usize, u8)> = symbols
                .iter()
                .map(|&(sym, _)| (sym, weights[sym]))
                .collect();
            by_weight.sort_unstable_by_key(|&(_, w)| w);

            for (sym, _) in by_weight {
                // Each +1 step adds 2^current_weight to the sum; only take
                // steps that do not overshoot the target.
                while weights[sym] < max_weight && kraft_sum < target {
                    let increase = 1u64 << weights[sym];
                    if kraft_sum + increase <= target {
                        kraft_sum += increase;
                        weights[sym] += 1;
                    } else {
                        break;
                    }
                }
            }
        } else if kraft_sum > target {
            // Over capacity: decrease weights (longer codes), starting
            // with the highest-weight symbols.
            let mut by_weight: Vec<(usize, u8)> = symbols
                .iter()
                .map(|&(sym, _)| (sym, weights[sym]))
                .collect();
            by_weight.sort_unstable_by_key(|&(_, w)| std::cmp::Reverse(w));

            for (sym, _) in by_weight {
                // Each -1 step removes 2^(new weight) from the sum.
                while weights[sym] > 1 && kraft_sum > target {
                    weights[sym] -= 1;
                    kraft_sum -= 1u64 << weights[sym];
                }
            }
        }

        // Final pass: fill any remaining capacity - O(n)
        // Handles cases where the decrease pass undershot the target.
        if kraft_sum < target {
            for &(sym, _) in &symbols {
                while weights[sym] < max_weight {
                    let increase = 1u64 << weights[sym];
                    if kraft_sum + increase <= target {
                        kraft_sum += increase;
                        weights[sym] += 1;
                    } else {
                        break;
                    }
                }
            }
        }

        Some((weights, max_weight))
    }
377
    /// Fix code lengths to satisfy Kraft inequality.
    ///
    /// For a valid (complete) Huffman code: sum(2^(max_len - len)) = 2^max_len.
    /// If the lengths overflow that budget, every code is deepened
    /// uniformly until the sum fits, then leftover capacity is returned
    /// by shortening the longest codes.
    #[allow(dead_code)]
    fn fix_kraft_inequality(code_lengths: &mut [u8], max_len: u8) {
        // 0 or 1 symbols impose no tree-shape constraint; nothing to fix.
        let num_symbols = code_lengths.iter().filter(|&&l| l > 0).count();
        if num_symbols <= 1 {
            return;
        }

        // Kraft sum measured against the current max_len budget.
        let kraft_sum: u64 = code_lengths
            .iter()
            .filter(|&&l| l > 0)
            .map(|&l| 1u64 << (max_len.saturating_sub(l)) as u32)
            .sum();
        let target = 1u64 << max_len;

        if kraft_sum <= target {
            // Already valid; if under budget, spend the slack on shorter codes.
            if kraft_sum < target {
                Self::fill_kraft_capacity(code_lengths, max_len, target - kraft_sum);
            }
            return;
        }

        // Over budget: the tree must get deeper. Smallest new_max_len with
        // 2^new_max_len >= kraft_sum.
        let new_max_len = (64 - kraft_sum.leading_zeros()) as u8;
        if new_max_len > MAX_WEIGHT {
            // Can't fix within the format's depth limit - leave unchanged.
            return;
        }

        // Deepen every code uniformly by the budget difference.
        // NOTE(review): the `.min(MAX_WEIGHT)` clamp can keep some codes
        // shallower than the uniform increase assumes, so the recomputed
        // sum below may still exceed the new budget in extreme cases —
        // confirm whether callers rely on a strictly valid result.
        let depth_increase = new_max_len - max_len;
        for len in code_lengths.iter_mut() {
            if *len > 0 {
                *len = (*len + depth_increase).min(MAX_WEIGHT);
            }
        }

        // Recompute against the deeper budget and hand back any slack.
        let new_kraft_sum: u64 = code_lengths
            .iter()
            .filter(|&&l| l > 0)
            .map(|&l| 1u64 << (new_max_len.saturating_sub(l)) as u32)
            .sum();
        let new_target = 1u64 << new_max_len;

        if new_kraft_sum < new_target {
            Self::fill_kraft_capacity(code_lengths, new_max_len, new_target - new_kraft_sum);
        }
    }
433
434    /// Fill unused Kraft capacity by shortening some code lengths.
435    #[allow(dead_code)]
436    fn fill_kraft_capacity(code_lengths: &mut [u8], max_len: u8, mut spare: u64) {
437        // Sort symbols by code length (longest first) to shorten long codes
438        let mut syms: Vec<_> = code_lengths
439            .iter()
440            .enumerate()
441            .filter(|&(_, &l)| l > 1)
442            .map(|(i, &l)| (i, l))
443            .collect();
444        syms.sort_by_key(|&(_, l)| std::cmp::Reverse(l));
445
446        for (idx, old_len) in syms {
447            if spare == 0 {
448                break;
449            }
450            // Shortening by 1: contribution goes from 2^(max_len-old_len) to 2^(max_len-old_len+1)
451            // Increase in usage: 2^(max_len-old_len)
452            let increase = 1u64 << (max_len.saturating_sub(old_len)) as u32;
453            if increase <= spare {
454                code_lengths[idx] = old_len - 1;
455                spare -= increase;
456            }
457        }
458    }
459
460    /// Limit code lengths to ensure they satisfy Kraft inequality.
461    /// Uses a simple algorithm to redistribute long codes.
462    #[allow(dead_code)]
463    fn limit_code_lengths(code_lengths: &mut [u8], max_len: u8) {
464        // Count symbols at each length
465        let mut counts = vec![0u32; max_len as usize + 1];
466        for &len in code_lengths.iter() {
467            if len > 0 && len <= max_len {
468                counts[len as usize] += 1;
469            } else if len > max_len {
470                counts[max_len as usize] += 1;
471            }
472        }
473
474        // Clamp all lengths to max_len
475        for len in code_lengths.iter_mut() {
476            if *len > max_len {
477                *len = max_len;
478            }
479        }
480
481        // Adjust to satisfy Kraft: sum(2^-len) <= 1
482        // Equivalently: sum(2^(max_len - len)) <= 2^max_len
483        loop {
484            let kraft_sum: u64 = counts
485                .iter()
486                .enumerate()
487                .skip(1)
488                .map(|(len, &count)| (count as u64) << (max_len as usize - len))
489                .sum();
490
491            let target = 1u64 << max_len;
492            if kraft_sum <= target {
493                break;
494            }
495
496            // Need to reduce: increase some code lengths
497            // Find the shortest non-empty bucket and move one symbol to next bucket
498            for len in 1..max_len as usize {
499                if counts[len] > 0 {
500                    counts[len] -= 1;
501                    counts[len + 1] += 1;
502                    // Update actual code lengths
503                    for code_len in code_lengths.iter_mut() {
504                        if *code_len == len as u8 {
505                            *code_len = (len + 1) as u8;
506                            break;
507                        }
508                    }
509                    break;
510                }
511            }
512        }
513    }
514
515    /// Generate canonical Huffman codes from weights.
516    fn generate_canonical_codes(weights: &[u8], max_bits: u8) -> [HuffmanCode; MAX_SYMBOLS] {
517        let mut codes = [HuffmanCode::default(); MAX_SYMBOLS];
518
519        // Count symbols at each code length
520        let mut bl_count = vec![0u32; max_bits as usize + 2];
521        for &w in weights {
522            if w > 0 {
523                let code_len = (max_bits + 1).saturating_sub(w) as usize;
524                if code_len < bl_count.len() {
525                    bl_count[code_len] += 1;
526                }
527            }
528        }
529
530        // Calculate starting codes for each length
531        let mut next_code = vec![0u32; max_bits as usize + 2];
532        let mut code = 0u32;
533        for (bits, next_code_entry) in next_code
534            .iter_mut()
535            .enumerate()
536            .take(max_bits as usize + 1)
537            .skip(1)
538        {
539            code = (code + bl_count.get(bits - 1).copied().unwrap_or(0)) << 1;
540            *next_code_entry = code;
541        }
542
543        // Assign codes to symbols
544        for (symbol, &w) in weights.iter().enumerate() {
545            if w > 0 && symbol < MAX_SYMBOLS {
546                let code_len = (max_bits + 1).saturating_sub(w) as usize;
547                if code_len < next_code.len() {
548                    codes[symbol] = HuffmanCode::new(next_code[code_len] as u16, code_len as u8);
549                    next_code[code_len] += 1;
550                }
551            }
552        }
553
554        codes
555    }
556
557    /// Encode literals using optimized bit packing.
558    ///
559    /// Uses 64-bit accumulator for efficient byte-aligned writes.
560    /// Optimized with chunked reverse processing and software prefetching
561    /// to maintain cache efficiency despite reverse iteration requirement.
562    ///
563    /// # Performance Optimizations
564    /// - Processes in 64-byte cache-line chunks (reverse chunk order, forward within chunk)
565    /// - Software prefetching brings next chunk into L1 cache ahead of time
566    /// - 64-bit accumulator with branchless 32-bit flushes
567    /// - Unrolled inner loop for better ILP
568    pub fn encode(&self, literals: &[u8]) -> Vec<u8> {
569        if literals.is_empty() {
570            return vec![0x01]; // Just sentinel
571        }
572
573        // Pre-allocate output with better estimate
574        let estimated_bits: usize = literals
575            .iter()
576            .take(256.min(literals.len()))
577            .map(|&b| self.codes[b as usize].num_bits as usize)
578            .sum();
579        let avg_bits = if literals.len() <= 256 {
580            estimated_bits
581        } else {
582            estimated_bits * literals.len() / 256.min(literals.len())
583        };
584        let mut output = Vec::with_capacity(avg_bits.div_ceil(8) + 16);
585
586        // 64-bit accumulator for efficient bit packing
587        let mut accum: u64 = 0;
588        let mut bits_in_accum: u32 = 0;
589
590        // Process in cache-line sized chunks (64 bytes) with prefetching
591        // This maintains cache efficiency despite reverse iteration
592        const CHUNK_SIZE: usize = 64;
593        let len = literals.len();
594        let mut pos = len;
595
596        while pos > 0 {
597            let chunk_start = pos.saturating_sub(CHUNK_SIZE);
598            let chunk_end = pos;
599
600            // Prefetch the NEXT chunk (earlier in memory) into L1 cache
601            // This hides memory latency by fetching ahead
602            #[cfg(target_arch = "x86_64")]
603            if chunk_start >= CHUNK_SIZE {
604                unsafe {
605                    use std::arch::x86_64::{_mm_prefetch, _MM_HINT_T0};
606                    _mm_prefetch(
607                        literals.as_ptr().add(chunk_start - CHUNK_SIZE) as *const i8,
608                        _MM_HINT_T0,
609                    );
610                }
611            }
612
613            // Process bytes within chunk in reverse order
614            // The chunk is now in L1 cache, so reverse iteration is fast
615            let chunk = &literals[chunk_start..chunk_end];
616
617            // Unroll by 4 for better instruction-level parallelism
618            let chunk_len = chunk.len();
619            let mut i = chunk_len;
620
621            // Handle tail (non-multiple of 4)
622            while i > 0 && !i.is_multiple_of(4) {
623                i -= 1;
624                let byte = chunk[i];
625                let code = &self.codes[byte as usize];
626                let num_bits = code.num_bits as u32;
627
628                if num_bits > 0 {
629                    accum |= (code.code as u64) << bits_in_accum;
630                    bits_in_accum += num_bits;
631
632                    if bits_in_accum >= 32 {
633                        output.extend_from_slice(&(accum as u32).to_le_bytes());
634                        accum >>= 32;
635                        bits_in_accum -= 32;
636                    }
637                }
638            }
639
640            // Process 4 bytes at a time (unrolled, branchless)
641            // Novel optimization: Remove all branches in the inner loop.
642            // Since num_bits==0 means the symbol isn't present (code==0, bits==0),
643            // we can unconditionally OR and ADD without changing the result.
644            // This enables better CPU pipelining and SIMD vectorization.
645            while i >= 4 {
646                i -= 4;
647
648                // Load 4 codes (compiler can pipeline these loads)
649                let c0 = self.codes[chunk[i + 3] as usize];
650                let c1 = self.codes[chunk[i + 2] as usize];
651                let c2 = self.codes[chunk[i + 1] as usize];
652                let c3 = self.codes[chunk[i] as usize];
653
654                // Branchless encoding: OR and ADD unconditionally
655                // For valid symbols: adds the code bits
656                // For invalid symbols (num_bits=0): OR 0, ADD 0 - no effect
657                accum |= (c0.code as u64) << bits_in_accum;
658                bits_in_accum += c0.num_bits as u32;
659                accum |= (c1.code as u64) << bits_in_accum;
660                bits_in_accum += c1.num_bits as u32;
661                accum |= (c2.code as u64) << bits_in_accum;
662                bits_in_accum += c2.num_bits as u32;
663                accum |= (c3.code as u64) << bits_in_accum;
664                bits_in_accum += c3.num_bits as u32;
665
666                // Branchless flush: always flush when >= 32 bits
667                // Using conditional move pattern that compilers optimize well
668                if bits_in_accum >= 32 {
669                    output.extend_from_slice(&(accum as u32).to_le_bytes());
670                    accum >>= 32;
671                    bits_in_accum -= 32;
672                }
673                // Second flush for cases where 4 symbols exceed 64 bits total
674                if bits_in_accum >= 32 {
675                    output.extend_from_slice(&(accum as u32).to_le_bytes());
676                    accum >>= 32;
677                    bits_in_accum -= 32;
678                }
679            }
680
681            pos = chunk_start;
682        }
683
684        // Add sentinel bit
685        accum |= 1u64 << bits_in_accum;
686        bits_in_accum += 1;
687
688        // Flush remaining bits (up to 5 bytes: 32 bits max + 1 sentinel)
689        let remaining_bytes = bits_in_accum.div_ceil(8);
690        for _ in 0..remaining_bytes {
691            output.push((accum & 0xFF) as u8);
692            accum >>= 8;
693        }
694
695        output
696    }
697
698    /// Encode literals in batches for better throughput.
699    ///
700    /// Processes 4 symbols at a time when possible.
701    #[allow(dead_code)]
702    pub fn encode_batch(&self, literals: &[u8]) -> Vec<u8> {
703        if literals.len() < 8 {
704            return self.encode(literals);
705        }
706
707        let mut output = Vec::with_capacity(literals.len() / 2 + 8);
708        let mut accum: u64 = 0;
709        let mut bits_in_accum: u32 = 0;
710
711        // Process in reverse, 4 symbols at a time
712        let len = literals.len();
713        let mut i = len;
714
715        // Handle tail (last 1-3 symbols)
716        while i > 0 && !i.is_multiple_of(4) {
717            i -= 1;
718            let code = &self.codes[literals[i] as usize];
719            if code.num_bits > 0 {
720                accum |= (code.code as u64) << bits_in_accum;
721                bits_in_accum += code.num_bits as u32;
722                if bits_in_accum >= 8 {
723                    output.push((accum & 0xFF) as u8);
724                    accum >>= 8;
725                    bits_in_accum -= 8;
726                }
727            }
728        }
729
730        // Process 4 symbols at a time
731        while i >= 4 {
732            i -= 4;
733
734            // Load 4 codes
735            let c0 = &self.codes[literals[i + 3] as usize];
736            let c1 = &self.codes[literals[i + 2] as usize];
737            let c2 = &self.codes[literals[i + 1] as usize];
738            let c3 = &self.codes[literals[i] as usize];
739
740            // Accumulate codes
741            accum |= (c0.code as u64) << bits_in_accum;
742            bits_in_accum += c0.num_bits as u32;
743            accum |= (c1.code as u64) << bits_in_accum;
744            bits_in_accum += c1.num_bits as u32;
745            accum |= (c2.code as u64) << bits_in_accum;
746            bits_in_accum += c2.num_bits as u32;
747            accum |= (c3.code as u64) << bits_in_accum;
748            bits_in_accum += c3.num_bits as u32;
749
750            // Flush complete bytes
751            while bits_in_accum >= 8 {
752                output.push((accum & 0xFF) as u8);
753                accum >>= 8;
754                bits_in_accum -= 8;
755            }
756        }
757
758        // Handle remaining symbols
759        while i > 0 {
760            i -= 1;
761            let code = &self.codes[literals[i] as usize];
762            if code.num_bits > 0 {
763                accum |= (code.code as u64) << bits_in_accum;
764                bits_in_accum += code.num_bits as u32;
765                if bits_in_accum >= 8 {
766                    output.push((accum & 0xFF) as u8);
767                    accum >>= 8;
768                    bits_in_accum -= 8;
769                }
770            }
771        }
772
773        // Add sentinel bit
774        accum |= 1u64 << bits_in_accum;
775        bits_in_accum += 1;
776
777        // Flush remaining
778        if bits_in_accum > 0 {
779            output.push((accum & 0xFF) as u8);
780        }
781
782        output
783    }
784
785    /// Serialize weights in Zstd format (direct or FSE-compressed).
786    ///
787    /// For num_symbols <= 128: Uses direct format
788    /// - header_byte = (num_symbols - 1) + 128
789    /// - Followed by ceil(num_symbols / 2) bytes of 4-bit weights
790    ///
791    /// For num_symbols > 128: Uses FSE-compressed format
792    /// - header_byte < 128 = compressed_size
793    /// - Followed by FSE table and compressed weights
794    pub fn serialize_weights(&self) -> Vec<u8> {
795        // Find last non-zero weight
796        let last_symbol = self
797            .weights
798            .iter()
799            .enumerate()
800            .filter(|&(_, w)| *w > 0)
801            .map(|(i, _)| i)
802            .max()
803            .unwrap_or(0);
804
805        let num_symbols = last_symbol + 1;
806
807        // Calculate direct encoding size
808        let direct_size = 1 + num_symbols.div_ceil(2);
809
810        // Try FSE-compressed weights if beneficial
811        // FSE is typically better when there are many zeros in the weight table
812        // (sparse symbol usage like ASCII text)
813        if num_symbols > 32 {
814            let fse_result = self.serialize_weights_fse(num_symbols);
815            if !fse_result.is_empty() && fse_result.len() < direct_size {
816                return fse_result;
817            }
818        }
819
820        // For >128 symbols, FSE is required
821        if num_symbols > 128 {
822            let fse_result = self.serialize_weights_fse(num_symbols);
823            if !fse_result.is_empty() {
824                return fse_result;
825            }
826            // FSE encoding failed, fall back to empty (caller should use raw block)
827            return Vec::new();
828        }
829
830        // Direct encoding for <= 128 symbols
831        let mut output = Vec::with_capacity(direct_size);
832
833        if num_symbols > 0 {
834            output.push(((num_symbols - 1) + 128) as u8);
835
836            // Pack weights as 4-bit nibbles
837            // Our decoder expects: Weight[i] in high nibble, Weight[i+1] in low nibble
838            for i in (0..num_symbols).step_by(2) {
839                let w1 = self.weights.get(i).copied().unwrap_or(0);
840                let w2 = self.weights.get(i + 1).copied().unwrap_or(0);
841                output.push((w1 << 4) | (w2 & 0x0F));
842            }
843        }
844
845        output
846    }
847
    /// Serialize weights using FSE compression for >128 symbols.
    ///
    /// Per RFC 8878 Section 4.2.1.1:
    /// - header_byte < 128 indicates FSE-compressed weights
    /// - header_byte value is the compressed size in bytes
    /// - Weights are encoded using an FSE table with max_symbol = 12 (weights 0-12)
    ///
    /// The FSE bitstream format for Huffman weights:
    /// 1. FSE table header (accuracy_log + probabilities)
    /// 2. Compressed bitstream read in reverse (from end with sentinel)
    ///    - Initial decoder state (accuracy_log bits, MSB-first from end)
    ///    - Encoded symbols' bits for state transitions
    ///
    /// Returns an empty `Vec` on any failure (no weights to encode, table
    /// construction error, no reachable decoder-state path, or a payload that
    /// exceeds the 127-byte header limit). Callers treat an empty result as
    /// "use a different weight format".
    fn serialize_weights_fse(&self, num_symbols: usize) -> Vec<u8> {
        // Histogram of weight values. Weights are 0-11 (MAX_WEIGHT), but the
        // table reserves indices up to 12, so count over 0-12.
        let mut weight_freq = [0i16; 13]; // 0-12 possible weight values
        for i in 0..num_symbols {
            let w = self.weights.get(i).copied().unwrap_or(0) as usize;
            if w <= 12 {
                weight_freq[w] += 1;
            }
        }

        // Choose accuracy_log (6 is typical for Huffman weights per RFC 8878)
        const WEIGHT_ACCURACY_LOG: u8 = 6;
        let table_size = 1i16 << WEIGHT_ACCURACY_LOG;

        // Normalize frequencies so they sum to table_size.
        let total: i16 = weight_freq.iter().sum();
        if total == 0 {
            return Vec::new(); // No weights to encode
        }

        let mut normalized = [0i16; 13];
        let mut remaining = table_size;

        // First pass: proportional share per present weight, floored at 1 so
        // no present symbol loses its table slot.
        for (i, &freq) in weight_freq.iter().enumerate() {
            if freq > 0 {
                let norm = ((freq as i32 * table_size as i32) / total as i32).max(1) as i16;
                normalized[i] = norm;
                remaining -= norm;
            }
        }

        // Under-allocation: give the leftover capacity to the most frequent
        // weight (greedy; all leftover slots go to the same index).
        while remaining > 0 {
            let mut best_idx = 0;
            let mut best_freq = 0;
            for (i, &freq) in weight_freq.iter().enumerate() {
                if freq > best_freq && normalized[i] > 0 {
                    best_freq = freq;
                    best_idx = i;
                }
            }
            if best_freq == 0 {
                break;
            }
            normalized[best_idx] += 1;
            remaining -= 1;
        }

        // Over-allocation (possible due to the .max(1) floor): shave slots
        // from the largest normalized count, never below 1.
        while remaining < 0 {
            let mut best_idx = 0;
            let mut best_norm = 0;
            for (i, &norm) in normalized.iter().enumerate() {
                if norm > 1 && norm > best_norm {
                    best_norm = norm;
                    best_idx = i;
                }
            }
            if best_norm <= 1 {
                break;
            }
            normalized[best_idx] -= 1;
            remaining += 1;
        }

        // Build FSE table from normalized frequencies
        let fse_table = match FseTable::build(&normalized, WEIGHT_ACCURACY_LOG, 12) {
            Ok(t) => t,
            Err(_) => return Vec::new(), // Failed to build table
        };

        // Serialize FSE table header
        let table_header = Self::serialize_fse_table_header(&normalized, WEIGHT_ACCURACY_LOG);

        // Encoding is simulation-based: find a sequence of DECODER states
        // that reproduces our weight sequence, then emit the bits each
        // transition would consume.
        //
        // The decoder works as:
        //   state → (symbol, baseline, num_bits)
        //   next_state = baseline + read_bits(num_bits)
        //
        // So we need states s0, s1, ... such that:
        //   table[s0].symbol = weight[0]
        //   table[s1].symbol = weight[1], with s1 = table[s0].baseline + bits0
        //   etc.

        // Collect weights to encode
        let weights_to_encode: Vec<u8> = (0..num_symbols)
            .map(|i| self.weights.get(i).copied().unwrap_or(0))
            .collect();

        // Invert the decode table: for each weight value, list every state
        // that decodes to it.
        let mut states_for_symbol: [Vec<usize>; 13] = Default::default();
        for state in 0..fse_table.size() {
            let entry = fse_table.decode(state);
            if (entry.symbol as usize) < 13 {
                states_for_symbol[entry.symbol as usize].push(state);
            }
        }

        // Bail out early if any needed weight has no state at all.
        for &w in &weights_to_encode {
            if states_for_symbol[w as usize].is_empty() {
                return Vec::new(); // Can't encode this weight
            }
        }

        // Greedy path search: for each symbol pick the first state reachable
        // from the previous state's (baseline, num_bits) window.
        let mut state_sequence = Vec::with_capacity(num_symbols);
        let mut bits_sequence: Vec<(u32, u8)> = Vec::with_capacity(num_symbols);

        // First state: pick any state for the first weight
        let first_weight = weights_to_encode[0] as usize;
        let first_state = states_for_symbol[first_weight][0];
        state_sequence.push(first_state);

        // For each subsequent weight, find a state and compute transition bits
        for i in 1..num_symbols {
            let prev_state = state_sequence[i - 1];
            let prev_entry = fse_table.decode(prev_state);
            let target_weight = weights_to_encode[i] as usize;

            // We need: next_state = baseline + bits
            // where table[next_state].symbol = target_weight
            // and bits < (1 << num_bits)
            let baseline = prev_entry.baseline as usize;
            let num_bits = prev_entry.num_bits;
            let max_bits_value = 1usize << num_bits;

            // Find a state for target_weight inside the reachable window.
            let mut found = false;
            for &candidate_state in &states_for_symbol[target_weight] {
                if candidate_state >= baseline && candidate_state < baseline + max_bits_value {
                    let bits = (candidate_state - baseline) as u32;
                    bits_sequence.push((bits, num_bits));
                    state_sequence.push(candidate_state);
                    found = true;
                    break;
                }
            }

            if !found {
                // Greedy dead end. A full implementation would backtrack and
                // try a different earlier state; we simply give up and let
                // the caller pick another format.
                return Vec::new(); // Can't find valid encoding path
            }
        }

        // Build the bitstream. The decoder reads, in order:
        // 1. Initial state (accuracy_log bits) — state_sequence[0]
        // 2. Transition bits for each subsequent symbol
        // 3. The final symbol comes straight from the last state (no bits).
        //
        // The stream is consumed in REVERSE (MSB-first from the end), so we
        // write everything in the opposite order of how it is read:
        // transition bits for s_{n-1} first, ..., bits for s1, then the
        // initial state last (so the decoder encounters it first).
        let mut bit_writer = FseBitWriter::new();

        // Write transition bits in reverse order
        for i in (0..bits_sequence.len()).rev() {
            let (bits, num_bits) = bits_sequence[i];
            bit_writer.write_bits(bits, num_bits);
        }

        // Write initial state (will be read first by decoder)
        bit_writer.write_bits(state_sequence[0] as u32, WEIGHT_ACCURACY_LOG);

        // Finish bitstream (adds sentinel)
        let mut compressed_stream = bit_writer.finish();

        // The FseBitWriter produces bits in LSB-first order within bytes,
        // but the reversed reader reads MSB-first. We need to bit-reverse each byte.
        for byte in &mut compressed_stream {
            *byte = byte.reverse_bits();
        }

        // Combine: FSE table header + compressed stream
        let total_compressed_size = table_header.len() + compressed_stream.len();

        // The single header byte must stay < 128 to mark the FSE format.
        if total_compressed_size >= 128 {
            return Vec::new(); // Too large for FSE format
        }

        // Build final output
        let mut output = Vec::with_capacity(1 + total_compressed_size);
        output.push(total_compressed_size as u8); // header < 128 = FSE compressed
        output.extend_from_slice(&table_header);
        output.extend_from_slice(&compressed_stream);

        output
    }
1071
1072    /// Serialize FSE table header for Huffman weights.
1073    ///
1074    /// Format: 4-bit accuracy_log + variable-length probabilities
1075    #[allow(dead_code)]
1076    fn serialize_fse_table_header(normalized: &[i16; 13], accuracy_log: u8) -> Vec<u8> {
1077        let mut output = Vec::with_capacity(16);
1078        let mut bit_pos = 0u32;
1079        let mut accum = 0u64;
1080
1081        // Write accuracy_log - 5 (4 bits)
1082        let acc_val = (accuracy_log.saturating_sub(5)) as u64;
1083        accum |= acc_val << bit_pos;
1084        bit_pos += 4;
1085
1086        // Write probabilities using variable-length encoding
1087        let table_size = 1i32 << accuracy_log;
1088        let mut remaining = table_size;
1089
1090        for &prob in normalized.iter() {
1091            if remaining <= 0 {
1092                break;
1093            }
1094
1095            // Calculate bits needed to encode this probability
1096            let max_bits = 32 - (remaining + 1).leading_zeros();
1097            let threshold = (1i32 << max_bits) - 1 - remaining;
1098
1099            // Encode probability
1100            let prob_val = if prob == -1 { 0 } else { prob as i32 };
1101
1102            if prob_val < threshold {
1103                // Small value: use max_bits - 1 bits
1104                accum |= (prob_val as u64) << bit_pos;
1105                bit_pos += max_bits - 1;
1106            } else {
1107                // Large value: use max_bits bits
1108                let large = prob_val + threshold;
1109                accum |= (large as u64) << bit_pos;
1110                bit_pos += max_bits;
1111            }
1112
1113            // Flush complete bytes
1114            while bit_pos >= 8 {
1115                output.push((accum & 0xFF) as u8);
1116                accum >>= 8;
1117                bit_pos -= 8;
1118            }
1119
1120            // Update remaining
1121            if prob == -1 {
1122                remaining -= 1;
1123            } else {
1124                remaining -= prob as i32;
1125            }
1126        }
1127
1128        // Flush remaining bits
1129        if bit_pos > 0 {
1130            output.push((accum & 0xFF) as u8);
1131        }
1132
1133        output
1134    }
1135
1136    /// Get maximum code length.
1137    #[inline]
1138    pub fn max_bits(&self) -> u8 {
1139        self.max_bits
1140    }
1141
1142    /// Get number of symbols with codes.
1143    #[inline]
1144    pub fn num_symbols(&self) -> usize {
1145        self.num_symbols
1146    }
1147
1148    /// Estimate compressed size.
1149    pub fn estimate_size(&self, literals: &[u8]) -> usize {
1150        let mut total_bits: usize = 0;
1151        for &byte in literals {
1152            total_bits += self.codes[byte as usize].num_bits as usize;
1153        }
1154        // Weight table size depends on last_symbol (highest symbol index), not unique count
1155        // Direct encoding uses (last_symbol + 1) symbols in the table
1156        let num_table_symbols = self.last_symbol + 1;
1157        let weight_table_size = 1 + num_table_symbols.div_ceil(2);
1158        total_bits.div_ceil(8) + weight_table_size
1159    }
1160
1161    /// Get code for a symbol (for testing).
1162    #[cfg(test)]
1163    pub fn get_codes(&self) -> &[HuffmanCode; MAX_SYMBOLS] {
1164        &self.codes
1165    }
1166}
1167
1168// =============================================================================
1169// Tests
1170// =============================================================================
1171
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a test buffer containing `n` copies of each `(symbol, n)` pair,
    /// in order — replaces the repetitive push-loops in every test.
    fn sample(counts: &[(u8, usize)]) -> Vec<u8> {
        let mut data = Vec::new();
        for &(sym, n) in counts {
            data.extend(std::iter::repeat(sym).take(n));
        }
        data
    }

    #[test]
    fn test_build_simple() {
        let data = sample(&[(b'a', 100), (b'b', 50), (b'c', 25)]);
        let encoder = HuffmanEncoder::build(&data);
        assert!(encoder.is_some());
        assert!(encoder.unwrap().num_symbols() >= 3);
    }

    #[test]
    fn test_build_too_small() {
        // Below MIN_HUFFMAN_SIZE the builder declines.
        assert!(HuffmanEncoder::build(b"small").is_none());
    }

    #[test]
    fn test_encode_simple() {
        let data = sample(&[(b'a', 100), (b'b', 50)]);
        if let Some(enc) = HuffmanEncoder::build(&data) {
            assert!(enc.encode(&data).len() < data.len());
        }
    }

    #[test]
    fn test_encode_batch() {
        let data = sample(&[(b'a', 100), (b'b', 50), (b'c', 25)]);
        if let Some(enc) = HuffmanEncoder::build(&data) {
            let regular = enc.encode(&data);
            let batch = enc.encode_batch(&data);

            // Both paths should produce valid compressed data.
            assert!(!regular.is_empty());
            assert!(!batch.is_empty());
        }
    }

    #[test]
    fn test_serialize_weights() {
        let data = sample(&[(b'a', 100), (b'b', 50)]);
        if let Some(enc) = HuffmanEncoder::build(&data) {
            let weights = enc.serialize_weights();
            assert!(!weights.is_empty());
            // Header byte >= 128 marks the direct (uncompressed) format.
            assert!(weights[0] >= 128);
        }
    }

    #[test]
    fn test_estimate_size() {
        let data = sample(&[(b'a', 100), (b'b', 50)]);
        if let Some(enc) = HuffmanEncoder::build(&data) {
            let estimated = enc.estimate_size(&data);
            let actual = enc.encode(&data).len() + enc.serialize_weights().len();
            assert!(estimated <= actual + 10);
        }
    }

    #[test]
    fn test_frequency_counting() {
        let data = vec![0u8, 1, 2, 0, 1, 0, 0, 0, 1, 2, 3];
        let freq = HuffmanEncoder::count_frequencies(&data);

        assert_eq!(freq[0], 5);
        assert_eq!(freq[1], 3);
        assert_eq!(freq[2], 2);
        assert_eq!(freq[3], 1);
    }

    #[test]
    fn test_huffman_code_alignment() {
        // HuffmanCode must stay 4 bytes with 4-byte alignment for the
        // cache-friendly table layout.
        assert_eq!(std::mem::size_of::<HuffmanCode>(), 4);
        assert_eq!(std::mem::align_of::<HuffmanCode>(), 4);
    }

    #[test]
    fn test_many_symbols_uses_direct_encoding() {
        // 100 distinct symbols (<= 128) with a skewed distribution.
        let mut data = Vec::new();
        for sym in 0..100u8 {
            let reps = (100 - sym as usize).max(1);
            data.extend(std::iter::repeat(sym).take(reps));
        }

        let encoder = HuffmanEncoder::build(&data);
        assert!(encoder.is_some(), "Should build encoder for 100 symbols");

        if let Some(enc) = encoder {
            let weights = enc.serialize_weights();
            assert!(!weights.is_empty(), "Should serialize weights");
            // Direct encoding is signalled by a header byte >= 128.
            assert!(
                weights[0] >= 128,
                "Should use direct format for <= 128 symbols"
            );
        }
    }

    #[test]
    fn test_fse_table_header_serialization() {
        // Spot-check the FSE table header serialization format.
        let normalized = [32i16, 16, 8, 4, 2, 1, 1, 0, 0, 0, 0, 0, 0];
        let header = HuffmanEncoder::serialize_fse_table_header(&normalized, 6);

        assert!(!header.is_empty());
        // Low nibble of the first byte encodes accuracy_log - 5 = 1.
        assert_eq!(header[0] & 0x0F, 1);
    }
}