haagenti_zstd/huffman/encoder.rs
1//! Huffman encoding for Zstd literals.
2//!
3//! This module implements high-performance Huffman encoding for Zstd compression.
4//!
5//! ## Optimizations
6//!
7//! - SIMD-accelerated frequency counting (histogram)
8//! - 64-bit accumulator for efficient bit packing
9//! - Cache-friendly code table layout
10//! - Vectorized encoding for batch processing
11//!
12//! ## Weight System
13//!
14//! In Zstd Huffman encoding:
15//! - Weight `w > 0` means `code_length = max_bits + 1 - w`
16//! - Weight `0` means symbol is not present
17//! - Higher weight = shorter code = more frequent symbol
18//! - Maximum weight is 11 (minimum code length = 1 bit)
19//!
20//! ## References
21//!
22//! - [RFC 8878 Section 4.2](https://datatracker.ietf.org/doc/html/rfc8878#section-4.2)
23
24use crate::fse::{FseBitWriter, FseTable};
25
/// Number of distinct symbols a Huffman table can cover (one per byte value).
const MAX_SYMBOLS: usize = 256;

/// Maximum Huffman weight per RFC 8878; since code_length = max_bits + 1 - weight,
/// this also caps the code length at 11 bits.
const MAX_WEIGHT: u8 = 11;

/// Inputs shorter than this cannot amortize the weight-table header;
/// `build` rejects them so callers fall back to raw/RLE literals.
const MIN_HUFFMAN_SIZE: usize = 32;
34
/// Huffman encoding table entry - packed for cache efficiency.
///
/// Four bytes per entry (u16 code + u8 length + pad), so 16 entries fit in a
/// 64-byte cache line during the encode loop's table lookups.
#[derive(Debug, Clone, Copy, Default)]
#[repr(C, align(4))]
pub struct HuffmanCode {
    /// The Huffman code bits (stored in LSB).
    pub code: u16,
    /// Number of bits in the code; 0 means the symbol is absent (and `code` is 0).
    pub num_bits: u8,
    /// Padding for alignment.
    _pad: u8,
}
46
47impl HuffmanCode {
48 #[inline]
49 const fn new(code: u16, num_bits: u8) -> Self {
50 Self {
51 code,
52 num_bits,
53 _pad: 0,
54 }
55 }
56}
57
/// Optimized Huffman encoder for literal compression.
///
/// Holds the symbol-to-code lookup table plus the weight table needed to
/// serialize the Huffman header (RFC 8878 Section 4.2.1).
#[derive(Debug)]
pub struct HuffmanEncoder {
    /// Encoding table: symbol -> code (256 entries, cache-aligned)
    codes: Box<[HuffmanCode; MAX_SYMBOLS]>,
    /// Symbol weights for serialization (one per byte value; 0 = absent)
    weights: Vec<u8>,
    /// Maximum code length in bits (<= MAX_WEIGHT)
    max_bits: u8,
    /// Number of symbols with non-zero weight
    num_symbols: usize,
    /// Highest symbol index with non-zero weight (for weight table sizing)
    last_symbol: usize,
}
72
73impl HuffmanEncoder {
74 /// Build a Huffman encoder from literal data.
75 ///
76 /// Uses SIMD-accelerated histogram when available.
77 /// Returns None if data cannot be efficiently Huffman-compressed.
78 pub fn build(data: &[u8]) -> Option<Self> {
79 if data.len() < MIN_HUFFMAN_SIZE {
80 return None;
81 }
82
83 // Count symbol frequencies using optimized histogram
84 let freq = Self::count_frequencies(data);
85
86 // Count unique symbols and find last symbol with non-zero frequency
87 let unique_count = freq.iter().filter(|&&f| f > 0).count();
88 if unique_count < 2 {
89 return None; // Use RLE instead
90 }
91
92 let last_symbol = freq
93 .iter()
94 .enumerate()
95 .filter(|&(_, &f)| f > 0)
96 .map(|(i, _)| i)
97 .max()
98 .unwrap_or(0);
99
100 // Convert frequencies to weights
101 let (weights, max_bits) = Self::frequencies_to_weights(&freq)?;
102
103 // Generate canonical codes
104 let codes = Self::generate_canonical_codes(&weights, max_bits);
105
106 Some(Self {
107 codes: Box::new(codes),
108 weights,
109 max_bits,
110 num_symbols: unique_count,
111 last_symbol,
112 })
113 }
114
115 /// Build a Huffman encoder from pre-defined weights.
116 ///
117 /// This allows using custom Huffman tables instead of building from data.
118 /// Useful when you have pre-trained weights from dictionary compression
119 /// or want to reuse weights across multiple blocks.
120 ///
121 /// # Parameters
122 ///
123 /// - `weights`: Array of 256 weights (one per byte value). Weight 0 means
124 /// symbol is not present. Weight w > 0 means code_length = max_bits + 1 - w.
125 ///
126 /// # Returns
127 ///
128 /// Returns `Some(encoder)` if the weights are valid, `None` otherwise.
129 ///
130 /// # Example
131 ///
132 /// ```rust
133 /// use haagenti_zstd::huffman::HuffmanEncoder;
134 ///
135 /// // Define weights for symbols 'a' (97), 'b' (98), 'c' (99)
136 /// let mut weights = vec![0u8; 256];
137 /// weights[97] = 3; // 'a' - highest weight (shortest code)
138 /// weights[98] = 2; // 'b' - medium weight
139 /// weights[99] = 1; // 'c' - lowest weight (longest code)
140 ///
141 /// let encoder = HuffmanEncoder::from_weights(&weights).unwrap();
142 /// ```
143 pub fn from_weights(weights: &[u8]) -> Option<Self> {
144 if weights.len() != MAX_SYMBOLS {
145 return None;
146 }
147
148 // Count unique symbols and find last symbol with non-zero weight
149 let unique_count = weights.iter().filter(|&&w| w > 0).count();
150 if unique_count < 2 {
151 return None; // Need at least 2 symbols
152 }
153
154 let last_symbol = weights
155 .iter()
156 .enumerate()
157 .filter(|&(_, &w)| w > 0)
158 .map(|(i, _)| i)
159 .max()
160 .unwrap_or(0);
161
162 // Find max weight to determine max_bits
163 let max_weight = *weights.iter().max().unwrap_or(&0);
164 if max_weight == 0 || max_weight > MAX_WEIGHT {
165 return None;
166 }
167
168 // Calculate max_bits from max_weight
169 // In Zstd: code_length = max_bits + 1 - weight
170 // For the highest weight symbol, code_length should be 1, so:
171 // max_bits = max_weight
172 let max_bits = max_weight;
173
174 // Generate canonical codes from weights
175 let codes = Self::generate_canonical_codes(weights, max_bits);
176
177 Some(Self {
178 codes: Box::new(codes),
179 weights: weights.to_vec(),
180 max_bits,
181 num_symbols: unique_count,
182 last_symbol,
183 })
184 }
185
186 /// Count byte frequencies using optimized histogram.
187 ///
188 /// Uses SIMD acceleration when available via haagenti_simd.
189 #[inline]
190 fn count_frequencies(data: &[u8]) -> [u32; MAX_SYMBOLS] {
191 // Use SIMD-accelerated histogram when feature is enabled
192 #[cfg(feature = "simd")]
193 {
194 haagenti_simd::byte_histogram(data)
195 }
196
197 // Optimized scalar fallback using 4-way interleaved counting
198 // This reduces cache line conflicts from histogram updates
199 #[cfg(not(feature = "simd"))]
200 {
201 let mut freq0 = [0u32; MAX_SYMBOLS];
202 let mut freq1 = [0u32; MAX_SYMBOLS];
203 let mut freq2 = [0u32; MAX_SYMBOLS];
204 let mut freq3 = [0u32; MAX_SYMBOLS];
205
206 // Process 16 bytes at a time with 4 interleaved histograms
207 let chunks = data.chunks_exact(16);
208 let remainder = chunks.remainder();
209
210 for chunk in chunks {
211 // Interleave to reduce pipeline stalls from same-address increments
212 freq0[chunk[0] as usize] += 1;
213 freq1[chunk[1] as usize] += 1;
214 freq2[chunk[2] as usize] += 1;
215 freq3[chunk[3] as usize] += 1;
216 freq0[chunk[4] as usize] += 1;
217 freq1[chunk[5] as usize] += 1;
218 freq2[chunk[6] as usize] += 1;
219 freq3[chunk[7] as usize] += 1;
220 freq0[chunk[8] as usize] += 1;
221 freq1[chunk[9] as usize] += 1;
222 freq2[chunk[10] as usize] += 1;
223 freq3[chunk[11] as usize] += 1;
224 freq0[chunk[12] as usize] += 1;
225 freq1[chunk[13] as usize] += 1;
226 freq2[chunk[14] as usize] += 1;
227 freq3[chunk[15] as usize] += 1;
228 }
229
230 // Handle remainder
231 for &byte in remainder {
232 freq0[byte as usize] += 1;
233 }
234
235 // Merge the 4 histograms
236 for i in 0..MAX_SYMBOLS {
237 freq0[i] += freq1[i] + freq2[i] + freq3[i];
238 }
239
240 freq0
241 }
242 }
243
    /// Convert frequencies to Zstd Huffman weights.
    ///
    /// Produces weights that satisfy the Kraft inequality:
    /// sum(2^weight) = 2^(max_weight + 1)
    ///
    /// # Algorithm Complexity: O(n log n)
    ///
    /// 1. Sort symbols by frequency: O(n log n)
    /// 2. Calculate initial weights based on frequency ratios: O(n)
    /// 3. Adjust weights to fill Kraft capacity using heap-based greedy: O(n log n)
    ///
    /// This replaces the previous O(n²) algorithm that used repeated full scans.
    ///
    /// Returns `None` when fewer than two symbols occur; the caller should use
    /// RLE for single-symbol data.
    ///
    /// NOTE(review): the adjustment phases move `kraft_sum` toward `target`
    /// greedily; exact equality (`kraft_sum == target`) is not proven for all
    /// inputs here — confirm downstream decoding tolerates an under-full table.
    fn frequencies_to_weights(freq: &[u32; MAX_SYMBOLS]) -> Option<(Vec<u8>, u8)> {
        // Collect non-zero frequency symbols
        let mut symbols: Vec<(usize, u32)> = freq
            .iter()
            .enumerate()
            .filter(|&(_, &f)| f > 0)
            .map(|(i, &f)| (i, f))
            .collect();

        if symbols.len() < 2 {
            return None;
        }

        let n = symbols.len();

        // Special case: exactly 2 symbols get weight 1 each (1-bit codes)
        if n == 2 {
            let mut weights = vec![0u8; MAX_SYMBOLS];
            weights[symbols[0].0] = 1;
            weights[symbols[1].0] = 1;
            return Some((weights, 1));
        }

        // Sort symbols by frequency (highest first) - O(n log n)
        symbols.sort_unstable_by(|a, b| b.1.cmp(&a.1));

        // Calculate max_weight needed for n symbols
        // (min_exp = ceil(log2(n)): minimum tree depth to distinguish n symbols)
        let min_exp = if n <= 2 {
            0
        } else {
            64 - ((n - 1) as u64).leading_zeros()
        };
        let max_weight = ((min_exp + 1) as u8).clamp(1, MAX_WEIGHT);

        let mut weights = vec![0u8; MAX_SYMBOLS];
        // Kraft budget: a complete table satisfies sum(2^weight) = 2^(max_weight + 1).
        let target = 1u64 << (max_weight + 1);

        // Phase 1: Assign initial weights based on frequency ratio - O(n)
        // Use log2(max_freq / freq) to estimate relative code lengths
        let max_freq = symbols[0].1 as u64;

        for (idx, &(sym, freq)) in symbols.iter().enumerate() {
            if idx == 0 {
                // Most frequent symbol gets max_weight (shortest code)
                weights[sym] = max_weight;
            } else {
                // Calculate weight based on frequency ratio
                // Higher ratio = lower frequency = lower weight = longer code
                // (ratio is computed as ceil(max_freq / freq))
                let ratio = (max_freq + freq as u64 - 1) / freq.max(1) as u64;
                let log_ratio = if ratio <= 1 {
                    0
                } else {
                    (64 - ratio.leading_zeros()).saturating_sub(1) as u8
                };
                // Clamp to valid range [1, max_weight]
                let w = max_weight.saturating_sub(log_ratio).max(1);
                weights[sym] = w;
            }
        }

        // Calculate current Kraft sum - O(n)
        let mut kraft_sum: u64 = symbols.iter().map(|(sym, _)| 1u64 << weights[*sym]).sum();

        // Phase 2: Adjust weights to satisfy Kraft inequality - O(n log n) worst case
        // Use a greedy approach: process symbols by weight (lowest first for increasing)

        if kraft_sum < target {
            // Under capacity: increase weights for symbols (shorter codes)
            // Process from lowest weight to highest (most room to increase)
            let mut by_weight: Vec<(usize, u8)> = symbols
                .iter()
                .map(|&(sym, _)| (sym, weights[sym]))
                .collect();
            by_weight.sort_unstable_by_key(|&(_, w)| w);

            for (sym, _) in by_weight {
                while weights[sym] < max_weight && kraft_sum < target {
                    // Raising weight w -> w+1 adds 2^w to the sum.
                    let increase = 1u64 << weights[sym];
                    if kraft_sum + increase <= target {
                        kraft_sum += increase;
                        weights[sym] += 1;
                    } else {
                        break;
                    }
                }
            }
        } else if kraft_sum > target {
            // Over capacity: decrease weights (longer codes)
            // Process from highest weight to lowest
            let mut by_weight: Vec<(usize, u8)> = symbols
                .iter()
                .map(|&(sym, _)| (sym, weights[sym]))
                .collect();
            by_weight.sort_unstable_by_key(|&(_, w)| std::cmp::Reverse(w));

            for (sym, _) in by_weight {
                // Dropping weight w -> w-1 removes 2^(w-1); the shift below is
                // evaluated after the decrement, so it subtracts exactly that.
                while weights[sym] > 1 && kraft_sum > target {
                    weights[sym] -= 1;
                    kraft_sum -= 1u64 << weights[sym];
                }
            }
        }

        // Final pass: fill any remaining capacity - O(n)
        // This handles edge cases where the above didn't fully utilize capacity
        if kraft_sum < target {
            for &(sym, _) in &symbols {
                while weights[sym] < max_weight {
                    let increase = 1u64 << weights[sym];
                    if kraft_sum + increase <= target {
                        kraft_sum += increase;
                        weights[sym] += 1;
                    } else {
                        break;
                    }
                }
            }
        }

        Some((weights, max_weight))
    }
377
    /// Fix code lengths to satisfy Kraft inequality.
    /// For a valid Huffman code: sum(2^(max_len - len)) = 2^max_len
    ///
    /// If the lengths over-subscribe the budget, every live code is lengthened
    /// uniformly (deepening the tree) and the resulting spare capacity is
    /// handed back via `fill_kraft_capacity`. Entries with length 0 (absent
    /// symbols) are ignored throughout.
    #[allow(dead_code)]
    fn fix_kraft_inequality(code_lengths: &mut [u8], max_len: u8) {
        // First, check if we need a deeper tree
        // Calculate minimum required depth for this many symbols
        let num_symbols = code_lengths.iter().filter(|&&l| l > 0).count();
        if num_symbols <= 1 {
            // Nothing to balance with zero or one live symbol.
            return;
        }

        // Calculate current Kraft sum with current max_len
        let kraft_sum: u64 = code_lengths
            .iter()
            .filter(|&&l| l > 0)
            .map(|&l| 1u64 << (max_len.saturating_sub(l)) as u32)
            .sum();
        let target = 1u64 << max_len;

        if kraft_sum <= target {
            // Already valid or has room to spare - try to fill unused capacity
            if kraft_sum < target {
                Self::fill_kraft_capacity(code_lengths, max_len, target - kraft_sum);
            }
            return;
        }

        // Need deeper tree: increase max_len until Kraft sum fits
        // New max_len must be large enough that 2^new_max_len >= kraft_sum
        let new_max_len = (64 - kraft_sum.leading_zeros()) as u8;
        if new_max_len > MAX_WEIGHT {
            // Can't fix - too many symbols
            return;
        }

        // Increase all code lengths by (new_max_len - max_len)
        let depth_increase = new_max_len - max_len;
        for len in code_lengths.iter_mut() {
            if *len > 0 {
                *len = (*len + depth_increase).min(MAX_WEIGHT);
            }
        }

        // Now we have spare capacity, fill it by shortening some codes
        let new_kraft_sum: u64 = code_lengths
            .iter()
            .filter(|&&l| l > 0)
            .map(|&l| 1u64 << (new_max_len.saturating_sub(l)) as u32)
            .sum();
        let new_target = 1u64 << new_max_len;

        if new_kraft_sum < new_target {
            Self::fill_kraft_capacity(code_lengths, new_max_len, new_target - new_kraft_sum);
        }
    }
433
434 /// Fill unused Kraft capacity by shortening some code lengths.
435 #[allow(dead_code)]
436 fn fill_kraft_capacity(code_lengths: &mut [u8], max_len: u8, mut spare: u64) {
437 // Sort symbols by code length (longest first) to shorten long codes
438 let mut syms: Vec<_> = code_lengths
439 .iter()
440 .enumerate()
441 .filter(|&(_, &l)| l > 1)
442 .map(|(i, &l)| (i, l))
443 .collect();
444 syms.sort_by_key(|&(_, l)| std::cmp::Reverse(l));
445
446 for (idx, old_len) in syms {
447 if spare == 0 {
448 break;
449 }
450 // Shortening by 1: contribution goes from 2^(max_len-old_len) to 2^(max_len-old_len+1)
451 // Increase in usage: 2^(max_len-old_len)
452 let increase = 1u64 << (max_len.saturating_sub(old_len)) as u32;
453 if increase <= spare {
454 code_lengths[idx] = old_len - 1;
455 spare -= increase;
456 }
457 }
458 }
459
460 /// Limit code lengths to ensure they satisfy Kraft inequality.
461 /// Uses a simple algorithm to redistribute long codes.
462 #[allow(dead_code)]
463 fn limit_code_lengths(code_lengths: &mut [u8], max_len: u8) {
464 // Count symbols at each length
465 let mut counts = vec![0u32; max_len as usize + 1];
466 for &len in code_lengths.iter() {
467 if len > 0 && len <= max_len {
468 counts[len as usize] += 1;
469 } else if len > max_len {
470 counts[max_len as usize] += 1;
471 }
472 }
473
474 // Clamp all lengths to max_len
475 for len in code_lengths.iter_mut() {
476 if *len > max_len {
477 *len = max_len;
478 }
479 }
480
481 // Adjust to satisfy Kraft: sum(2^-len) <= 1
482 // Equivalently: sum(2^(max_len - len)) <= 2^max_len
483 loop {
484 let kraft_sum: u64 = counts
485 .iter()
486 .enumerate()
487 .skip(1)
488 .map(|(len, &count)| (count as u64) << (max_len as usize - len))
489 .sum();
490
491 let target = 1u64 << max_len;
492 if kraft_sum <= target {
493 break;
494 }
495
496 // Need to reduce: increase some code lengths
497 // Find the shortest non-empty bucket and move one symbol to next bucket
498 for len in 1..max_len as usize {
499 if counts[len] > 0 {
500 counts[len] -= 1;
501 counts[len + 1] += 1;
502 // Update actual code lengths
503 for code_len in code_lengths.iter_mut() {
504 if *code_len == len as u8 {
505 *code_len = (len + 1) as u8;
506 break;
507 }
508 }
509 break;
510 }
511 }
512 }
513 }
514
    /// Generate canonical Huffman codes from weights.
    ///
    /// Standard canonical construction: count codes per length, derive the
    /// first code value for each length (shorter codes first), then hand out
    /// consecutive codes to symbols in ascending symbol order, so equal-length
    /// codes are ordered by symbol value.
    fn generate_canonical_codes(weights: &[u8], max_bits: u8) -> [HuffmanCode; MAX_SYMBOLS] {
        let mut codes = [HuffmanCode::default(); MAX_SYMBOLS];

        // Count symbols at each code length (code_len = max_bits + 1 - weight)
        let mut bl_count = vec![0u32; max_bits as usize + 2];
        for &w in weights {
            if w > 0 {
                let code_len = (max_bits + 1).saturating_sub(w) as usize;
                if code_len < bl_count.len() {
                    bl_count[code_len] += 1;
                }
            }
        }

        // Calculate starting codes for each length
        // (classic next_code recurrence: shift in the count of the previous length)
        let mut next_code = vec![0u32; max_bits as usize + 2];
        let mut code = 0u32;
        for (bits, next_code_entry) in next_code
            .iter_mut()
            .enumerate()
            .take(max_bits as usize + 1)
            .skip(1)
        {
            code = (code + bl_count.get(bits - 1).copied().unwrap_or(0)) << 1;
            *next_code_entry = code;
        }

        // Assign codes to symbols, consuming consecutive values per length
        for (symbol, &w) in weights.iter().enumerate() {
            if w > 0 && symbol < MAX_SYMBOLS {
                let code_len = (max_bits + 1).saturating_sub(w) as usize;
                if code_len < next_code.len() {
                    codes[symbol] = HuffmanCode::new(next_code[code_len] as u16, code_len as u8);
                    next_code[code_len] += 1;
                }
            }
        }

        codes
    }
556
557 /// Encode literals using optimized bit packing.
558 ///
559 /// Uses 64-bit accumulator for efficient byte-aligned writes.
560 /// Optimized with chunked reverse processing and software prefetching
561 /// to maintain cache efficiency despite reverse iteration requirement.
562 ///
563 /// # Performance Optimizations
564 /// - Processes in 64-byte cache-line chunks (reverse chunk order, forward within chunk)
565 /// - Software prefetching brings next chunk into L1 cache ahead of time
566 /// - 64-bit accumulator with branchless 32-bit flushes
567 /// - Unrolled inner loop for better ILP
568 pub fn encode(&self, literals: &[u8]) -> Vec<u8> {
569 if literals.is_empty() {
570 return vec![0x01]; // Just sentinel
571 }
572
573 // Pre-allocate output with better estimate
574 let estimated_bits: usize = literals
575 .iter()
576 .take(256.min(literals.len()))
577 .map(|&b| self.codes[b as usize].num_bits as usize)
578 .sum();
579 let avg_bits = if literals.len() <= 256 {
580 estimated_bits
581 } else {
582 estimated_bits * literals.len() / 256.min(literals.len())
583 };
584 let mut output = Vec::with_capacity(avg_bits.div_ceil(8) + 16);
585
586 // 64-bit accumulator for efficient bit packing
587 let mut accum: u64 = 0;
588 let mut bits_in_accum: u32 = 0;
589
590 // Process in cache-line sized chunks (64 bytes) with prefetching
591 // This maintains cache efficiency despite reverse iteration
592 const CHUNK_SIZE: usize = 64;
593 let len = literals.len();
594 let mut pos = len;
595
596 while pos > 0 {
597 let chunk_start = pos.saturating_sub(CHUNK_SIZE);
598 let chunk_end = pos;
599
600 // Prefetch the NEXT chunk (earlier in memory) into L1 cache
601 // This hides memory latency by fetching ahead
602 #[cfg(target_arch = "x86_64")]
603 if chunk_start >= CHUNK_SIZE {
604 unsafe {
605 use std::arch::x86_64::{_mm_prefetch, _MM_HINT_T0};
606 _mm_prefetch(
607 literals.as_ptr().add(chunk_start - CHUNK_SIZE) as *const i8,
608 _MM_HINT_T0,
609 );
610 }
611 }
612
613 // Process bytes within chunk in reverse order
614 // The chunk is now in L1 cache, so reverse iteration is fast
615 let chunk = &literals[chunk_start..chunk_end];
616
617 // Unroll by 4 for better instruction-level parallelism
618 let chunk_len = chunk.len();
619 let mut i = chunk_len;
620
621 // Handle tail (non-multiple of 4)
622 while i > 0 && !i.is_multiple_of(4) {
623 i -= 1;
624 let byte = chunk[i];
625 let code = &self.codes[byte as usize];
626 let num_bits = code.num_bits as u32;
627
628 if num_bits > 0 {
629 accum |= (code.code as u64) << bits_in_accum;
630 bits_in_accum += num_bits;
631
632 if bits_in_accum >= 32 {
633 output.extend_from_slice(&(accum as u32).to_le_bytes());
634 accum >>= 32;
635 bits_in_accum -= 32;
636 }
637 }
638 }
639
640 // Process 4 bytes at a time (unrolled, branchless)
641 // Novel optimization: Remove all branches in the inner loop.
642 // Since num_bits==0 means the symbol isn't present (code==0, bits==0),
643 // we can unconditionally OR and ADD without changing the result.
644 // This enables better CPU pipelining and SIMD vectorization.
645 while i >= 4 {
646 i -= 4;
647
648 // Load 4 codes (compiler can pipeline these loads)
649 let c0 = self.codes[chunk[i + 3] as usize];
650 let c1 = self.codes[chunk[i + 2] as usize];
651 let c2 = self.codes[chunk[i + 1] as usize];
652 let c3 = self.codes[chunk[i] as usize];
653
654 // Branchless encoding: OR and ADD unconditionally
655 // For valid symbols: adds the code bits
656 // For invalid symbols (num_bits=0): OR 0, ADD 0 - no effect
657 accum |= (c0.code as u64) << bits_in_accum;
658 bits_in_accum += c0.num_bits as u32;
659 accum |= (c1.code as u64) << bits_in_accum;
660 bits_in_accum += c1.num_bits as u32;
661 accum |= (c2.code as u64) << bits_in_accum;
662 bits_in_accum += c2.num_bits as u32;
663 accum |= (c3.code as u64) << bits_in_accum;
664 bits_in_accum += c3.num_bits as u32;
665
666 // Branchless flush: always flush when >= 32 bits
667 // Using conditional move pattern that compilers optimize well
668 if bits_in_accum >= 32 {
669 output.extend_from_slice(&(accum as u32).to_le_bytes());
670 accum >>= 32;
671 bits_in_accum -= 32;
672 }
673 // Second flush for cases where 4 symbols exceed 64 bits total
674 if bits_in_accum >= 32 {
675 output.extend_from_slice(&(accum as u32).to_le_bytes());
676 accum >>= 32;
677 bits_in_accum -= 32;
678 }
679 }
680
681 pos = chunk_start;
682 }
683
684 // Add sentinel bit
685 accum |= 1u64 << bits_in_accum;
686 bits_in_accum += 1;
687
688 // Flush remaining bits (up to 5 bytes: 32 bits max + 1 sentinel)
689 let remaining_bytes = bits_in_accum.div_ceil(8);
690 for _ in 0..remaining_bytes {
691 output.push((accum & 0xFF) as u8);
692 accum >>= 8;
693 }
694
695 output
696 }
697
698 /// Encode literals in batches for better throughput.
699 ///
700 /// Processes 4 symbols at a time when possible.
701 #[allow(dead_code)]
702 pub fn encode_batch(&self, literals: &[u8]) -> Vec<u8> {
703 if literals.len() < 8 {
704 return self.encode(literals);
705 }
706
707 let mut output = Vec::with_capacity(literals.len() / 2 + 8);
708 let mut accum: u64 = 0;
709 let mut bits_in_accum: u32 = 0;
710
711 // Process in reverse, 4 symbols at a time
712 let len = literals.len();
713 let mut i = len;
714
715 // Handle tail (last 1-3 symbols)
716 while i > 0 && !i.is_multiple_of(4) {
717 i -= 1;
718 let code = &self.codes[literals[i] as usize];
719 if code.num_bits > 0 {
720 accum |= (code.code as u64) << bits_in_accum;
721 bits_in_accum += code.num_bits as u32;
722 if bits_in_accum >= 8 {
723 output.push((accum & 0xFF) as u8);
724 accum >>= 8;
725 bits_in_accum -= 8;
726 }
727 }
728 }
729
730 // Process 4 symbols at a time
731 while i >= 4 {
732 i -= 4;
733
734 // Load 4 codes
735 let c0 = &self.codes[literals[i + 3] as usize];
736 let c1 = &self.codes[literals[i + 2] as usize];
737 let c2 = &self.codes[literals[i + 1] as usize];
738 let c3 = &self.codes[literals[i] as usize];
739
740 // Accumulate codes
741 accum |= (c0.code as u64) << bits_in_accum;
742 bits_in_accum += c0.num_bits as u32;
743 accum |= (c1.code as u64) << bits_in_accum;
744 bits_in_accum += c1.num_bits as u32;
745 accum |= (c2.code as u64) << bits_in_accum;
746 bits_in_accum += c2.num_bits as u32;
747 accum |= (c3.code as u64) << bits_in_accum;
748 bits_in_accum += c3.num_bits as u32;
749
750 // Flush complete bytes
751 while bits_in_accum >= 8 {
752 output.push((accum & 0xFF) as u8);
753 accum >>= 8;
754 bits_in_accum -= 8;
755 }
756 }
757
758 // Handle remaining symbols
759 while i > 0 {
760 i -= 1;
761 let code = &self.codes[literals[i] as usize];
762 if code.num_bits > 0 {
763 accum |= (code.code as u64) << bits_in_accum;
764 bits_in_accum += code.num_bits as u32;
765 if bits_in_accum >= 8 {
766 output.push((accum & 0xFF) as u8);
767 accum >>= 8;
768 bits_in_accum -= 8;
769 }
770 }
771 }
772
773 // Add sentinel bit
774 accum |= 1u64 << bits_in_accum;
775 bits_in_accum += 1;
776
777 // Flush remaining
778 if bits_in_accum > 0 {
779 output.push((accum & 0xFF) as u8);
780 }
781
782 output
783 }
784
    /// Serialize weights in Zstd format (direct or FSE-compressed).
    ///
    /// For num_symbols <= 128: Uses direct format
    /// - header_byte = (num_symbols - 1) + 128
    /// - Followed by ceil(num_symbols / 2) bytes of 4-bit weights
    ///
    /// For num_symbols > 128: Uses FSE-compressed format
    /// - header_byte < 128 = compressed_size
    /// - Followed by FSE table and compressed weights
    ///
    /// Returns an empty Vec when neither encoding fits (the caller should
    /// fall back to a raw block).
    ///
    /// NOTE(review): all `num_symbols` weights are emitted, including the
    /// last one; reference zstd omits the final weight and reconstructs it
    /// from the Kraft remainder — confirm the paired decoder expects this
    /// full-table layout.
    pub fn serialize_weights(&self) -> Vec<u8> {
        // Find last non-zero weight
        let last_symbol = self
            .weights
            .iter()
            .enumerate()
            .filter(|&(_, w)| *w > 0)
            .map(|(i, _)| i)
            .max()
            .unwrap_or(0);

        let num_symbols = last_symbol + 1;

        // Calculate direct encoding size
        let direct_size = 1 + num_symbols.div_ceil(2);

        // Try FSE-compressed weights if beneficial
        // FSE is typically better when there are many zeros in the weight table
        // (sparse symbol usage like ASCII text)
        if num_symbols > 32 {
            let fse_result = self.serialize_weights_fse(num_symbols);
            if !fse_result.is_empty() && fse_result.len() < direct_size {
                return fse_result;
            }
        }

        // For >128 symbols, FSE is required
        if num_symbols > 128 {
            let fse_result = self.serialize_weights_fse(num_symbols);
            if !fse_result.is_empty() {
                return fse_result;
            }
            // FSE encoding failed, fall back to empty (caller should use raw block)
            return Vec::new();
        }

        // Direct encoding for <= 128 symbols
        let mut output = Vec::with_capacity(direct_size);

        if num_symbols > 0 {
            // Header byte in 128..=255 encodes the weight count directly.
            output.push(((num_symbols - 1) + 128) as u8);

            // Pack weights as 4-bit nibbles
            // Our decoder expects: Weight[i] in high nibble, Weight[i+1] in low nibble
            for i in (0..num_symbols).step_by(2) {
                let w1 = self.weights.get(i).copied().unwrap_or(0);
                let w2 = self.weights.get(i + 1).copied().unwrap_or(0);
                output.push((w1 << 4) | (w2 & 0x0F));
            }
        }

        output
    }
847
    /// Serialize weights using FSE compression for >128 symbols.
    ///
    /// Per RFC 8878 Section 4.2.1.1:
    /// - header_byte < 128 indicates FSE-compressed weights
    /// - header_byte value is the compressed size in bytes
    /// - Weights are encoded using an FSE table with max_symbol = 12 (weights 0-12)
    ///
    /// The FSE bitstream format for Huffman weights:
    /// 1. FSE table header (accuracy_log + probabilities)
    /// 2. Compressed bitstream read in reverse (from end with sentinel)
    ///    - Initial decoder state (accuracy_log bits, MSB-first from end)
    ///    - Encoded symbols' bits for state transitions
    ///
    /// Returns an empty Vec on any failure (table build error, no reachable
    /// encoding path, or compressed size >= 128); the caller falls back to
    /// direct encoding or a raw block.
    fn serialize_weights_fse(&self, num_symbols: usize) -> Vec<u8> {
        // Count frequency of each weight value (weights are 0-11)
        let mut weight_freq = [0i16; 13]; // 0-12 possible weight values
        for i in 0..num_symbols {
            let w = self.weights.get(i).copied().unwrap_or(0) as usize;
            if w <= 12 {
                weight_freq[w] += 1;
            }
        }

        // Choose accuracy_log (6 is typical for Huffman weights per RFC 8878)
        const WEIGHT_ACCURACY_LOG: u8 = 6;
        let table_size = 1i16 << WEIGHT_ACCURACY_LOG;

        // Normalize frequencies to sum to table_size
        let total: i16 = weight_freq.iter().sum();
        if total == 0 {
            return Vec::new(); // No weights to encode
        }

        let mut normalized = [0i16; 13];
        let mut remaining = table_size;

        // First pass: assign proportional counts (each live weight gets >= 1)
        for (i, &freq) in weight_freq.iter().enumerate() {
            if freq > 0 {
                let norm = ((freq as i32 * table_size as i32) / total as i32).max(1) as i16;
                normalized[i] = norm;
                remaining -= norm;
            }
        }

        // Distribute remaining capacity to largest frequencies
        while remaining > 0 {
            let mut best_idx = 0;
            let mut best_freq = 0;
            for (i, &freq) in weight_freq.iter().enumerate() {
                if freq > best_freq && normalized[i] > 0 {
                    best_freq = freq;
                    best_idx = i;
                }
            }
            if best_freq == 0 {
                break;
            }
            normalized[best_idx] += 1;
            remaining -= 1;
        }

        // Handle over-allocation (can happen due to rounding):
        // shave counts off the largest normalized entries, never below 1.
        while remaining < 0 {
            let mut best_idx = 0;
            let mut best_norm = 0;
            for (i, &norm) in normalized.iter().enumerate() {
                if norm > 1 && norm > best_norm {
                    best_norm = norm;
                    best_idx = i;
                }
            }
            if best_norm <= 1 {
                break;
            }
            normalized[best_idx] -= 1;
            remaining += 1;
        }

        // Build FSE table from normalized frequencies
        let fse_table = match FseTable::build(&normalized, WEIGHT_ACCURACY_LOG, 12) {
            Ok(t) => t,
            Err(_) => return Vec::new(), // Failed to build table
        };

        // Serialize FSE table header
        let table_header = Self::serialize_fse_table_header(&normalized, WEIGHT_ACCURACY_LOG);

        // For FSE encoding, we use a simulation-based approach:
        // 1. Find the sequence of decoder states that produces our weight sequence
        // 2. Work backwards to compute the bits needed for each transition
        //
        // The decoder works as:
        //   state → (symbol, baseline, num_bits)
        //   next_state = baseline + read_bits(num_bits)
        //
        // So for encoding, we need to find states s0, s1, ... such that:
        //   table[s0].symbol = weight[0]
        //   table[s1].symbol = weight[1], and s1 = table[s0].baseline + bits0
        //   etc.

        // Collect weights to encode
        let weights_to_encode: Vec<u8> = (0..num_symbols)
            .map(|i| self.weights.get(i).copied().unwrap_or(0))
            .collect();

        // Find valid decoder state sequence
        // For each weight value, find all states that decode to it
        let mut states_for_symbol: [Vec<usize>; 13] = Default::default();
        for state in 0..fse_table.size() {
            let entry = fse_table.decode(state);
            if (entry.symbol as usize) < 13 {
                states_for_symbol[entry.symbol as usize].push(state);
            }
        }

        // Check if all weight values have at least one state
        for &w in &weights_to_encode {
            if states_for_symbol[w as usize].is_empty() {
                return Vec::new(); // Can't encode this weight
            }
        }

        // Use greedy approach: for each symbol, pick a state that works
        // and compute the bits needed for the transition from the previous state
        let mut state_sequence = Vec::with_capacity(num_symbols);
        let mut bits_sequence: Vec<(u32, u8)> = Vec::with_capacity(num_symbols);

        // First state: pick any state for the first weight
        let first_weight = weights_to_encode[0] as usize;
        let first_state = states_for_symbol[first_weight][0];
        state_sequence.push(first_state);

        // For each subsequent weight, find a state and compute transition bits
        for i in 1..num_symbols {
            let prev_state = state_sequence[i - 1];
            let prev_entry = fse_table.decode(prev_state);
            let target_weight = weights_to_encode[i] as usize;

            // We need: next_state = baseline + bits
            // where table[next_state].symbol = target_weight
            // and bits < (1 << num_bits)
            let baseline = prev_entry.baseline as usize;
            let num_bits = prev_entry.num_bits;
            let max_bits_value = 1usize << num_bits;

            // Find a state for target_weight that can be reached
            let mut found = false;
            for &candidate_state in &states_for_symbol[target_weight] {
                if candidate_state >= baseline && candidate_state < baseline + max_bits_value {
                    let bits = (candidate_state - baseline) as u32;
                    bits_sequence.push((bits, num_bits));
                    state_sequence.push(candidate_state);
                    found = true;
                    break;
                }
            }

            if !found {
                // Try wrapping around by using a different previous state
                // This is a simplification - full implementation would backtrack
                return Vec::new(); // Can't find valid encoding path
            }
        }

        // Now build the bitstream
        // The decoder reads:
        // 1. Initial state (accuracy_log bits) - this is state_sequence[0]
        // 2. For each symbol after the first, read bits for next state
        // 3. Final symbol is decoded from current state without reading more bits
        //
        // The bitstream is read in reverse (MSB-first from end).
        // So we write: [transition bits...][initial_state][sentinel]
        // And the bytes need to be arranged so that reversed reading works.

        // Build forward bitstream (we'll handle reversal through the writer)
        let mut bit_writer = FseBitWriter::new();

        // Write transition bits in order (they'll be read in reverse)
        // But wait - the reversed reader reads from the end, so the LAST bits
        // written should be read FIRST (as initial state).
        //
        // We need:
        // - Write initial_state last (so it's at the end, read first)
        // - Write transition bits before that
        //
        // Current approach: write bits in reverse order of how decoder reads
        // Decoder reads: init_state, then bits for s1, bits for s2, ...
        // We write: bits for s_{n-1}, bits for s_{n-2}, ..., bits for s1, init_state

        // Write transition bits in reverse order
        for i in (0..bits_sequence.len()).rev() {
            let (bits, num_bits) = bits_sequence[i];
            bit_writer.write_bits(bits, num_bits);
        }

        // Write initial state (will be read first by decoder)
        bit_writer.write_bits(state_sequence[0] as u32, WEIGHT_ACCURACY_LOG);

        // Finish bitstream (adds sentinel)
        let mut compressed_stream = bit_writer.finish();

        // The FseBitWriter produces bits in LSB-first order within bytes,
        // but the reversed reader reads MSB-first. We need to bit-reverse each byte.
        // NOTE(review): this per-byte reversal assumes FseBitWriter's packing
        // order exactly mirrors the decoder's — verify against the FSE reader.
        for byte in &mut compressed_stream {
            *byte = byte.reverse_bits();
        }

        // Combine: FSE table header + compressed stream
        let total_compressed_size = table_header.len() + compressed_stream.len();

        // Check if compressed size fits in header byte (< 128)
        if total_compressed_size >= 128 {
            return Vec::new(); // Too large for FSE format
        }

        // Build final output
        let mut output = Vec::with_capacity(1 + total_compressed_size);
        output.push(total_compressed_size as u8); // header < 128 = FSE compressed
        output.extend_from_slice(&table_header);
        output.extend_from_slice(&compressed_stream);

        output
    }
1071
1072 /// Serialize FSE table header for Huffman weights.
1073 ///
1074 /// Format: 4-bit accuracy_log + variable-length probabilities
1075 #[allow(dead_code)]
1076 fn serialize_fse_table_header(normalized: &[i16; 13], accuracy_log: u8) -> Vec<u8> {
1077 let mut output = Vec::with_capacity(16);
1078 let mut bit_pos = 0u32;
1079 let mut accum = 0u64;
1080
1081 // Write accuracy_log - 5 (4 bits)
1082 let acc_val = (accuracy_log.saturating_sub(5)) as u64;
1083 accum |= acc_val << bit_pos;
1084 bit_pos += 4;
1085
1086 // Write probabilities using variable-length encoding
1087 let table_size = 1i32 << accuracy_log;
1088 let mut remaining = table_size;
1089
1090 for &prob in normalized.iter() {
1091 if remaining <= 0 {
1092 break;
1093 }
1094
1095 // Calculate bits needed to encode this probability
1096 let max_bits = 32 - (remaining + 1).leading_zeros();
1097 let threshold = (1i32 << max_bits) - 1 - remaining;
1098
1099 // Encode probability
1100 let prob_val = if prob == -1 { 0 } else { prob as i32 };
1101
1102 if prob_val < threshold {
1103 // Small value: use max_bits - 1 bits
1104 accum |= (prob_val as u64) << bit_pos;
1105 bit_pos += max_bits - 1;
1106 } else {
1107 // Large value: use max_bits bits
1108 let large = prob_val + threshold;
1109 accum |= (large as u64) << bit_pos;
1110 bit_pos += max_bits;
1111 }
1112
1113 // Flush complete bytes
1114 while bit_pos >= 8 {
1115 output.push((accum & 0xFF) as u8);
1116 accum >>= 8;
1117 bit_pos -= 8;
1118 }
1119
1120 // Update remaining
1121 if prob == -1 {
1122 remaining -= 1;
1123 } else {
1124 remaining -= prob as i32;
1125 }
1126 }
1127
1128 // Flush remaining bits
1129 if bit_pos > 0 {
1130 output.push((accum & 0xFF) as u8);
1131 }
1132
1133 output
1134 }
1135
    /// Get maximum code length.
    ///
    /// Per the weight system (see module docs), a symbol of weight `w > 0`
    /// has code length `max_bits + 1 - w`, so this is the bit length of the
    /// longest (least frequent) code in the table.
    #[inline]
    pub fn max_bits(&self) -> u8 {
        self.max_bits
    }
1141
    /// Get number of symbols with codes.
    ///
    /// Counts symbols with a non-zero weight, i.e. those actually assigned a
    /// Huffman code by `build`.
    #[inline]
    pub fn num_symbols(&self) -> usize {
        self.num_symbols
    }
1147
1148 /// Estimate compressed size.
1149 pub fn estimate_size(&self, literals: &[u8]) -> usize {
1150 let mut total_bits: usize = 0;
1151 for &byte in literals {
1152 total_bits += self.codes[byte as usize].num_bits as usize;
1153 }
1154 // Weight table size depends on last_symbol (highest symbol index), not unique count
1155 // Direct encoding uses (last_symbol + 1) symbols in the table
1156 let num_table_symbols = self.last_symbol + 1;
1157 let weight_table_size = 1 + num_table_symbols.div_ceil(2);
1158 total_bits.div_ceil(8) + weight_table_size
1159 }
1160
    /// Get code for a symbol (for testing).
    ///
    /// Exposes the full 256-entry `symbol -> HuffmanCode` table so tests can
    /// inspect individual codes and their bit lengths.
    #[cfg(test)]
    pub fn get_codes(&self) -> &[HuffmanCode; MAX_SYMBOLS] {
        &self.codes
    }
1166}
1167
1168// =============================================================================
1169// Tests
1170// =============================================================================
1171
1172#[cfg(test)]
1173mod tests {
1174 use super::*;
1175
1176 #[test]
1177 fn test_build_simple() {
1178 let mut data = Vec::new();
1179 for _ in 0..100 {
1180 data.push(b'a');
1181 }
1182 for _ in 0..50 {
1183 data.push(b'b');
1184 }
1185 for _ in 0..25 {
1186 data.push(b'c');
1187 }
1188
1189 let encoder = HuffmanEncoder::build(&data);
1190 assert!(encoder.is_some());
1191
1192 let encoder = encoder.unwrap();
1193 assert!(encoder.num_symbols() >= 3);
1194 }
1195
1196 #[test]
1197 fn test_build_too_small() {
1198 let data = b"small";
1199 let encoder = HuffmanEncoder::build(data);
1200 assert!(encoder.is_none());
1201 }
1202
1203 #[test]
1204 fn test_encode_simple() {
1205 let mut data = Vec::new();
1206 for _ in 0..100 {
1207 data.push(b'a');
1208 }
1209 for _ in 0..50 {
1210 data.push(b'b');
1211 }
1212
1213 let encoder = HuffmanEncoder::build(&data);
1214 if let Some(enc) = encoder {
1215 let compressed = enc.encode(&data);
1216 assert!(compressed.len() < data.len());
1217 }
1218 }
1219
1220 #[test]
1221 fn test_encode_batch() {
1222 let mut data = Vec::new();
1223 for _ in 0..100 {
1224 data.push(b'a');
1225 }
1226 for _ in 0..50 {
1227 data.push(b'b');
1228 }
1229 for _ in 0..25 {
1230 data.push(b'c');
1231 }
1232
1233 let encoder = HuffmanEncoder::build(&data);
1234 if let Some(enc) = encoder {
1235 let regular = enc.encode(&data);
1236 let batch = enc.encode_batch(&data);
1237
1238 // Both should produce valid compressed data
1239 assert!(!regular.is_empty());
1240 assert!(!batch.is_empty());
1241 }
1242 }
1243
1244 #[test]
1245 fn test_serialize_weights() {
1246 let mut data = Vec::new();
1247 for _ in 0..100 {
1248 data.push(b'a');
1249 }
1250 for _ in 0..50 {
1251 data.push(b'b');
1252 }
1253
1254 let encoder = HuffmanEncoder::build(&data);
1255 if let Some(enc) = encoder {
1256 let weights = enc.serialize_weights();
1257 assert!(!weights.is_empty());
1258 assert!(weights[0] >= 128); // Direct format
1259 }
1260 }
1261
1262 #[test]
1263 fn test_estimate_size() {
1264 let mut data = Vec::new();
1265 for _ in 0..100 {
1266 data.push(b'a');
1267 }
1268 for _ in 0..50 {
1269 data.push(b'b');
1270 }
1271
1272 let encoder = HuffmanEncoder::build(&data);
1273 if let Some(enc) = encoder {
1274 let estimated = enc.estimate_size(&data);
1275 let actual = enc.encode(&data).len() + enc.serialize_weights().len();
1276 assert!(estimated <= actual + 10);
1277 }
1278 }
1279
1280 #[test]
1281 fn test_frequency_counting() {
1282 let data = vec![0u8, 1, 2, 0, 1, 0, 0, 0, 1, 2, 3];
1283 let freq = HuffmanEncoder::count_frequencies(&data);
1284
1285 assert_eq!(freq[0], 5);
1286 assert_eq!(freq[1], 3);
1287 assert_eq!(freq[2], 2);
1288 assert_eq!(freq[3], 1);
1289 }
1290
1291 #[test]
1292 fn test_huffman_code_alignment() {
1293 // Verify HuffmanCode is properly aligned
1294 assert_eq!(std::mem::size_of::<HuffmanCode>(), 4);
1295 assert_eq!(std::mem::align_of::<HuffmanCode>(), 4);
1296 }
1297
1298 #[test]
1299 fn test_many_symbols_uses_direct_encoding() {
1300 // Test with many unique symbols (but <= 128)
1301 // Create data with 100 unique symbols
1302 let mut data = Vec::new();
1303 for sym in 0..100u8 {
1304 for _ in 0..(100 - sym as usize).max(1) {
1305 data.push(sym);
1306 }
1307 }
1308
1309 let encoder = HuffmanEncoder::build(&data);
1310 assert!(encoder.is_some(), "Should build encoder for 100 symbols");
1311
1312 if let Some(enc) = encoder {
1313 let weights = enc.serialize_weights();
1314 assert!(!weights.is_empty(), "Should serialize weights");
1315 // Should use direct encoding (header >= 128)
1316 assert!(
1317 weights[0] >= 128,
1318 "Should use direct format for <= 128 symbols"
1319 );
1320 }
1321 }
1322
1323 #[test]
1324 fn test_fse_table_header_serialization() {
1325 // Test the FSE table header serialization format
1326 let normalized = [32i16, 16, 8, 4, 2, 1, 1, 0, 0, 0, 0, 0, 0];
1327 let header = HuffmanEncoder::serialize_fse_table_header(&normalized, 6);
1328
1329 // Header should not be empty
1330 assert!(!header.is_empty());
1331
1332 // First 4 bits should be accuracy_log - 5 = 1
1333 assert_eq!(header[0] & 0x0F, 1);
1334 }
1335}