Skip to main content

jxl_encoder/entropy_coding/
ans.rs

1// Copyright (c) Imazen LLC and the JPEG XL Project Authors.
2// Algorithms and constants derived from libjxl (BSD-3-Clause).
3// Licensed under AGPL-3.0-or-later. Commercial licenses at https://www.imazen.io/pricing
4
5//! ANS (Asymmetric Numeral Systems) encoder.
6//!
7//! ANS is an entropy coding method used in JPEG XL for efficient symbol
8//! compression. This module implements the rANS (range ANS) variant.
9
10use crate::bit_writer::BitWriter;
11use crate::error::{Error, Result};
12
/// Base-2 logarithm of the ANS table size.
pub const ANS_LOG_TAB_SIZE: u32 = 12;
/// Number of slots in the ANS table (`2^12 = 4096`); normalized frequencies sum to this.
pub const ANS_TAB_SIZE: u32 = 1 << ANS_LOG_TAB_SIZE;
/// Bit mask selecting the low `ANS_LOG_TAB_SIZE` bits of a value.
pub const ANS_TAB_MASK: u32 = ANS_TAB_SIZE - 1;

/// Largest alphabet an ANS distribution may have.
pub const ANS_MAX_ALPHABET_SIZE: usize = 256;

/// Signature that seeds the encoder state (the initial state is `ANS_SIGNATURE << 16`).
pub const ANS_SIGNATURE: u32 = 0x13;

/// Symbol of the logcount prefix code that marks a run (RLE).
#[allow(dead_code)]
const RLE_MARKER_SYM: u8 = 13;
27
/// Prefix code used to write log-frequency symbols (0-13).
///
/// Entry `i` is `(nbits, code_lsb)`: the code for symbol `i` is the low `nbits`
/// bits of `code_lsb`. The table is the encoder-side inverse of the lookup table
/// the jxl-rs decoder uses.
const LOGCOUNT_PREFIX_CODE: [(u8, u8); 14] = [
    // Symbol 0 stands for a zero frequency rather than a logcount.
    (5, 0b10001),
    // Symbols 1..=5: frequencies 1, [2,3], [4,7], [8,15], [16,31].
    (4, 0b1011),
    (4, 0b1111),
    (4, 0b0011),
    (4, 0b1001),
    (4, 0b0111),
    // Symbols 6..=10: frequencies [32,63], [64,127], [128,255], [256,511], [512,1023].
    (3, 0b100),
    (3, 0b010),
    (3, 0b101),
    (3, 0b110),
    (3, 0b000),
    // Symbols 11 and 12: frequencies [1024,2047] and [2048,4095].
    (6, 0b100001),
    (7, 0b0000001),
    // Symbol 13: RLE marker.
    (7, 0b1000001),
];
47
/// Number of fractional bits in the fixed-point reciprocal that replaces division.
const RECIPROCAL_PRECISION: u32 = 44;

/// Everything the encoder needs in order to code one symbol.
#[derive(Debug, Clone)]
pub struct AnsEncSymbolInfo {
    /// Normalized frequency of the symbol; 0 for a symbol that never occurs.
    pub freq: u16,
    /// Fixed-point reciprocal `ceil((1 << 44) / freq)`, or 0 when `freq` is 0.
    pub ifreq: u64,
    /// Table offset for each remainder value; empty until the reverse maps are built.
    pub reverse_map: Vec<u16>,
}

impl AnsEncSymbolInfo {
    /// Builds the info for a symbol with normalized frequency `freq`.
    ///
    /// Only `freq` and its reciprocal are filled in; `reverse_map` starts empty.
    pub fn new(freq: u16) -> Self {
        let ifreq = match freq {
            0 => 0,
            nonzero => (1u64 << RECIPROCAL_PRECISION).div_ceil(u64::from(nonzero)),
        };

        Self {
            freq,
            ifreq,
            reverse_map: Vec::new(),
        }
    }
}
78
79/// ANS encoder state.
80pub struct AnsEncoder {
81    /// ANS state (normalized to range).
82    state: u32,
83    /// Accumulated output bits (stored reversed).
84    bits: Vec<(u32, u8)>, // (bits, nbits)
85}
86
87impl AnsEncoder {
88    /// Creates a new ANS encoder.
89    pub fn new() -> Self {
90        Self {
91            state: ANS_SIGNATURE << 16,
92            bits: Vec::new(),
93        }
94    }
95
96    /// Creates a new ANS encoder with pre-allocated capacity for `num_tokens` tokens.
97    pub fn with_capacity(num_tokens: usize) -> Self {
98        Self {
99            state: ANS_SIGNATURE << 16,
100            bits: Vec::with_capacity(num_tokens * 2), // ~2 entries per token (symbol + extra bits)
101        }
102    }
103
104    /// Encodes a single symbol using precomputed symbol info.
105    ///
106    /// Returns the bits that should be output (if any).
107    #[inline]
108    pub fn put_symbol(&mut self, info: &AnsEncSymbolInfo) {
109        let freq = info.freq as u32;
110
111        // Renormalization: if state is too large, emit 16 bits
112        if (self.state >> (32 - ANS_LOG_TAB_SIZE)) >= freq {
113            self.bits.push((self.state & 0xFFFF, 16));
114            self.state >>= 16;
115        }
116
117        // State update using multiplication-by-reciprocal
118        // v = state / freq (approximately)
119        let v = ((self.state as u64 * info.ifreq) >> RECIPROCAL_PRECISION) as u32;
120        let remainder = self.state - v * freq;
121
122        // Look up offset in reverse map
123        let offset = info.reverse_map[remainder as usize] as u32;
124
125        // Update state
126        self.state = (v << ANS_LOG_TAB_SIZE) + offset;
127    }
128
129    /// Pushes extra bits into the encoder's output buffer.
130    ///
131    /// Used for HybridUint extra bits that are interleaved with ANS symbols.
132    /// These bits are stored in the same reversed buffer and will be emitted
133    /// in proper order during finalize().
134    #[inline]
135    pub fn push_bits(&mut self, bits: u32, nbits: u8) {
136        if nbits > 0 {
137            self.bits.push((bits, nbits));
138        }
139    }
140
141    /// Finalizes encoding and writes to a BitWriter.
142    ///
143    /// Writes the final state followed by all accumulated bits in reverse order.
144    pub fn finalize(self, writer: &mut BitWriter) -> Result<()> {
145        // Debug: show final state
146        #[cfg(feature = "debug-tokens")]
147        eprintln!(
148            "ANS finalize: state=0x{:08x}, {} bit chunks",
149            self.state,
150            self.bits.len()
151        );
152
153        // Write final state (32 bits)
154        writer.write(32, self.state as u64)?;
155
156        // Write accumulated bits in reverse order
157        for &(bits, nbits) in self.bits.iter().rev() {
158            writer.write(nbits as usize, bits as u64)?;
159        }
160
161        Ok(())
162    }
163
164    /// Returns the current state.
165    pub fn state(&self) -> u32 {
166        self.state
167    }
168}
169
170impl Default for AnsEncoder {
171    fn default() -> Self {
172        Self::new()
173    }
174}
175
176/// A complete ANS distribution with encoding info for all symbols.
177#[derive(Debug, Clone)]
178pub struct AnsDistribution {
179    /// Symbol encoding information.
180    pub symbols: Vec<AnsEncSymbolInfo>,
181    /// Log2 of distribution size (typically 12).
182    pub log_alpha_size: u32,
183    /// Total of normalized frequencies (should be ANS_TAB_SIZE).
184    pub total: u32,
185}
186
187impl AnsDistribution {
188    /// Creates a distribution from raw frequencies.
189    ///
190    /// Normalizes frequencies to sum to ANS_TAB_SIZE.
191    pub fn from_frequencies(freqs: &[u32]) -> Result<Self> {
192        if freqs.is_empty() {
193            return Err(Error::InvalidHistogram("empty distribution".to_string()));
194        }
195
196        let total_count: u64 = freqs.iter().map(|&f| f as u64).sum();
197        if total_count == 0 {
198            return Err(Error::InvalidHistogram("all zero frequencies".to_string()));
199        }
200
201        // Normalize frequencies to sum to ANS_TAB_SIZE
202        let mut normalized: Vec<u16> = Vec::with_capacity(freqs.len());
203        let mut running_total: u32 = 0;
204
205        for &freq in freqs.iter() {
206            let normalized_freq = if freq == 0 {
207                0
208            } else {
209                // Scale to ANS_TAB_SIZE, ensuring at least 1 for non-zero
210                ((freq as u64 * ANS_TAB_SIZE as u64) / total_count).max(1) as u16
211            };
212            normalized.push(normalized_freq);
213            running_total += normalized_freq as u32;
214        }
215
216        // Adjust to exactly sum to ANS_TAB_SIZE
217        let diff = running_total as i32 - ANS_TAB_SIZE as i32;
218        if diff != 0 {
219            // Find largest frequency and adjust it
220            if let Some((max_idx, _)) = normalized
221                .iter()
222                .enumerate()
223                .filter(|&(_, &f)| f > 0)
224                .max_by_key(|&(_, &f)| f)
225            {
226                let new_val = (normalized[max_idx] as i32 - diff).max(1) as u16;
227                normalized[max_idx] = new_val;
228            }
229        }
230
231        // Build symbol info with reverse maps
232        let mut symbols: Vec<AnsEncSymbolInfo> = normalized
233            .iter()
234            .map(|&f| AnsEncSymbolInfo::new(f))
235            .collect();
236
237        // Build reverse map (alias table)
238        Self::build_reverse_maps(&mut symbols)?;
239
240        Ok(Self {
241            symbols,
242            log_alpha_size: ANS_LOG_TAB_SIZE,
243            total: ANS_TAB_SIZE,
244        })
245    }
246
247    /// Creates a flat (uniform) distribution.
248    pub fn flat(alphabet_size: usize) -> Result<Self> {
249        if alphabet_size == 0 || alphabet_size > ANS_TAB_SIZE as usize {
250            return Err(Error::InvalidHistogram(format!(
251                "invalid alphabet size: {}",
252                alphabet_size
253            )));
254        }
255
256        let base_freq = ANS_TAB_SIZE as usize / alphabet_size;
257        let remainder = ANS_TAB_SIZE as usize % alphabet_size;
258
259        let mut freqs = vec![base_freq as u32; alphabet_size];
260        for freq in freqs.iter_mut().take(remainder) {
261            *freq += 1;
262        }
263
264        Self::from_frequencies(&freqs)
265    }
266
267    /// Creates a distribution from pre-normalized counts.
268    ///
269    /// The counts must already sum to ANS_TAB_SIZE (4096). This is used
270    /// when building distributions from ANSEncodingHistogram which has
271    /// already done the normalization.
272    pub fn from_normalized_counts(counts: &[i32]) -> Result<Self> {
273        if counts.is_empty() {
274            return Err(Error::InvalidHistogram("empty distribution".to_string()));
275        }
276
277        // Verify sum
278        let total: i32 = counts.iter().sum();
279        if total != ANS_TAB_SIZE as i32 {
280            return Err(Error::InvalidHistogram(format!(
281                "normalized counts sum to {} instead of {}",
282                total, ANS_TAB_SIZE
283            )));
284        }
285
286        // Build symbol info
287        let mut symbols: Vec<AnsEncSymbolInfo> = counts
288            .iter()
289            .map(|&c| AnsEncSymbolInfo::new(c.max(0) as u16))
290            .collect();
291
292        // Build reverse maps
293        Self::build_reverse_maps(&mut symbols)?;
294
295        Ok(Self {
296            symbols,
297            log_alpha_size: ANS_LOG_TAB_SIZE,
298            total: ANS_TAB_SIZE,
299        })
300    }
301
302    /// Builds reverse maps for all symbols using the alias table method.
303    ///
304    /// The decoder uses an alias table with buckets. Each idx in [0, 4096) maps to some
305    /// symbol and an offset within that symbol's range. The encoder needs to know,
306    /// for each symbol s and remainder r in [0, freq[s]), what idx to output.
307    ///
308    /// This exactly mirrors the decoder's build_alias_map and read methods.
309    fn build_reverse_maps(symbols: &mut [AnsEncSymbolInfo]) -> Result<()> {
310        let alphabet_size = symbols.len();
311        if alphabet_size == 0 {
312            return Ok(());
313        }
314
315        // Verify frequencies sum to ANS_TAB_SIZE
316        let total: u32 = symbols.iter().map(|s| s.freq as u32).sum();
317        if total != ANS_TAB_SIZE {
318            return Err(Error::InvalidHistogram(format!(
319                "frequencies sum to {} instead of {}",
320                total, ANS_TAB_SIZE
321            )));
322        }
323
324        // Special case: single-symbol distribution
325        // jxl-rs uses a simplified alias table where offset = idx for all positions.
326        // This means reverse_map[r] = r (identity mapping) for the single symbol.
327        if let Some(single_sym_idx) = symbols.iter().position(|s| s.freq == ANS_TAB_SIZE as u16) {
328            // Clear all reverse maps
329            for sym in symbols.iter_mut() {
330                sym.reverse_map.clear();
331            }
332            // Set identity mapping for the single symbol
333            let map = &mut symbols[single_sym_idx].reverse_map;
334            map.resize(ANS_TAB_SIZE as usize, 0);
335            for (i, v) in map.iter_mut().enumerate() {
336                *v = i as u16;
337            }
338            return Ok(());
339        }
340
341        // Build the alias table exactly like jxl-rs decoder does.
342        // Standard JXL ANS uses log_alpha_size=6 (64 buckets). We only increase
343        // it when the alphabet is too large for 64 buckets.
344        let log_alpha_size = if alphabet_size <= 64 {
345            6 // Standard value, matches decoder expectations
346        } else {
347            let min_bits = (alphabet_size - 1).ilog2() as usize + 1;
348            min_bits.min(ANS_LOG_TAB_SIZE as usize)
349        };
350        let table_size = 1usize << log_alpha_size;
351        let log_bucket_size = ANS_LOG_TAB_SIZE as usize - log_alpha_size;
352        let bucket_size = 1u16 << log_bucket_size;
353
354        // Working bucket structure matching jxl-rs
355        #[derive(Clone)]
356        #[allow(dead_code)]
357        struct WorkingBucket {
358            dist: u16,         // Frequency of primary symbol
359            alias_symbol: u16, // Alias symbol (used when pos >= cutoff)
360            alias_offset: u16, // Offset for alias symbol
361            alias_cutoff: u16, // Positions [0, cutoff) use primary, [cutoff, bucket_size) use alias
362        }
363
364        let mut buckets: Vec<WorkingBucket> = (0..table_size)
365            .map(|i| {
366                let dist = if i < alphabet_size {
367                    symbols[i].freq
368                } else {
369                    0
370                };
371                WorkingBucket {
372                    dist,
373                    alias_symbol: if i < alphabet_size { i as u16 } else { 0 },
374                    alias_offset: 0,
375                    alias_cutoff: dist,
376                }
377            })
378            .collect();
379
380        // Separate into underfull and overfull
381        let mut underfull: Vec<usize> = Vec::with_capacity(table_size);
382        let mut overfull: Vec<usize> = Vec::with_capacity(table_size);
383        for (i, bucket) in buckets.iter().enumerate() {
384            if bucket.alias_cutoff < bucket_size {
385                underfull.push(i);
386            } else if bucket.alias_cutoff > bucket_size {
387                overfull.push(i);
388            }
389        }
390
391        // Alias redistribution - exactly matching jxl-rs
392        while let (Some(o), Some(u)) = (overfull.pop(), underfull.pop()) {
393            let by = bucket_size - buckets[u].alias_cutoff;
394            buckets[o].alias_cutoff -= by;
395            buckets[u].alias_symbol = o as u16;
396            buckets[u].alias_offset = buckets[o].alias_cutoff;
397
398            match buckets[o].alias_cutoff.cmp(&bucket_size) {
399                std::cmp::Ordering::Less => underfull.push(o),
400                std::cmp::Ordering::Greater => overfull.push(o),
401                std::cmp::Ordering::Equal => {}
402            }
403        }
404
405        // Pre-allocate reverse maps with exact sizes (offset is 0..freq-1)
406        for sym in symbols.iter_mut() {
407            sym.reverse_map.clear();
408            sym.reverse_map.resize(sym.freq as usize, 0);
409        }
410
411        // For each idx in [0, 4096), simulate the decoder to find which symbol it decodes to
412        // and what offset within that symbol's range. Write directly into reverse_map[offset].
413        for idx in 0..ANS_TAB_SIZE {
414            let bucket_idx = (idx >> log_bucket_size) as usize;
415            let pos = (idx as u16) & (bucket_size - 1);
416
417            let bucket = &buckets[bucket_idx.min(table_size - 1)];
418            let alias_cutoff = bucket.alias_cutoff;
419
420            let (symbol, offset) = if pos < alias_cutoff {
421                (bucket_idx, pos)
422            } else {
423                let alias_sym = bucket.alias_symbol as usize;
424                let offset = bucket.alias_offset - alias_cutoff + pos;
425                (alias_sym, offset)
426            };
427
428            if symbol < alphabet_size {
429                symbols[symbol].reverse_map[offset as usize] = idx as u16;
430            }
431        }
432
433        Ok(())
434    }
435
436    /// Returns the number of symbols in this distribution.
437    pub fn alphabet_size(&self) -> usize {
438        self.symbols.len()
439    }
440
441    /// Gets the encoding info for a symbol.
442    pub fn get(&self, symbol: usize) -> Option<&AnsEncSymbolInfo> {
443        self.symbols.get(symbol)
444    }
445
446    /// Writes this distribution to a BitWriter.
447    pub fn write(&self, writer: &mut BitWriter) -> Result<()> {
448        // Check if this is a flat distribution
449        let is_flat = self.is_flat();
450
451        writer.write(1, 0)?; // Non-small tree marker
452        writer.write(1, u64::from(is_flat))?;
453
454        if is_flat {
455            // Flat distribution: just encode alphabet size
456            write_var_len_uint8(writer, (self.alphabet_size() - 1) as u8)?;
457        } else {
458            // General distribution encoding
459            // For now, use the simplest encoding (shift = 0, meaning power-of-2 counts)
460            self.write_general(writer)?;
461        }
462
463        Ok(())
464    }
465
466    /// Checks if this is a flat (uniform) distribution.
467    fn is_flat(&self) -> bool {
468        let first_freq = self.symbols.first().map(|s| s.freq).unwrap_or(0);
469        if first_freq == 0 {
470            return false;
471        }
472        self.symbols
473            .iter()
474            .all(|s| s.freq == first_freq || s.freq == first_freq - 1)
475    }
476
477    /// Writes a general (non-flat) distribution.
478    fn write_general(&self, writer: &mut BitWriter) -> Result<()> {
479        // Encode shift (we use shift=12 for maximum precision)
480        let method: u64 = 13; // shift + 1
481        let upper_bound_log = 4; // floor_log2(13)
482        let log = floor_log2(method as u32);
483
484        // Write unary prefix
485        writer.write(log as usize, (1u64 << log) - 1)?;
486        if log != upper_bound_log {
487            writer.write(1, 0)?;
488        }
489        // Write value suffix
490        writer.write(log as usize, ((1u64 << log) - 1) & method)?;
491
492        // Encode alphabet size
493        write_var_len_uint8(writer, (self.alphabet_size() - 3) as u8)?;
494
495        // For a simple implementation, encode each frequency directly
496        // Full implementation would use Huffman for bit-widths + RLE
497        for sym in &self.symbols {
498            // Encode frequency using a simple variable-length code
499            let freq = sym.freq;
500            if freq == 0 {
501                writer.write(1, 0)?;
502            } else {
503                writer.write(1, 1)?;
504                let bits = 16 - freq.leading_zeros();
505                writer.write(4, bits as u64)?;
506                if bits > 0 {
507                    writer.write(bits as usize, freq as u64)?;
508                }
509            }
510        }
511
512        Ok(())
513    }
514}
515
516/// Writes a variable-length uint8.
517fn write_var_len_uint8(writer: &mut BitWriter, n: u8) -> Result<()> {
518    if n == 0 {
519        writer.write(1, 0)?;
520    } else {
521        writer.write(1, 1)?;
522        let nbits = 8 - n.leading_zeros();
523        writer.write(3, (nbits - 1) as u64)?;
524        writer.write((nbits - 1) as usize, (n as u64) - (1u64 << (nbits - 1)))?;
525    }
526    Ok(())
527}
528
/// Floor of the base-2 logarithm of `n`; returns 0 for `n == 0`.
#[inline]
pub fn floor_log2_ans(n: u32) -> u32 {
    n.checked_ilog2().unwrap_or(0)
}

/// Module-local shorthand for [`floor_log2_ans`].
#[inline]
fn floor_log2(n: u32) -> u32 {
    floor_log2_ans(n)
}
539
540/// Precision calculation for frequency encoding.
541///
542/// Determines how many bits of precision to use when encoding a frequency count.
543/// Larger counts can be encoded with less precision.
544///
545/// Matches libjxl's `GetPopulationCountPrecision` from `ans_common.h`.
546pub fn get_population_count_precision(logcount: u32, shift: u32) -> u32 {
547    let logcount_i = logcount as i32;
548    let shift_i = shift as i32;
549    let r = logcount_i.min(shift_i - ((ANS_LOG_TAB_SIZE as i32 - logcount_i) >> 1));
550    r.max(0) as u32
551}
552
/// How many shift values are tried when normalizing an ANS histogram.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub enum ANSHistogramStrategy {
    /// Try only a handful of shift values; the fastest option.
    Fast,
    /// Try every second shift value.
    Approximate,
    /// Try every shift value; the default, giving the best compression.
    #[default]
    Precise,
}
564
/// ANS histogram normalized for encoding.
///
/// For a non-empty source histogram the counts sum to `ANS_TAB_SIZE` (4096).
#[derive(Clone, Debug)]
pub struct ANSEncodingHistogram {
    /// Normalized frequency of each symbol.
    pub counts: Vec<i32>,
    /// Alphabet size (highest non-zero symbol + 1).
    pub alphabet_size: usize,
    /// Estimated cost in bits (header + data).
    pub cost: f32,
    /// Encoding method as stored by `from_histogram`:
    /// - 0: flat distribution (used for an empty histogram)
    /// - 1: small code (at most two symbols)
    /// - otherwise `min(shift, ANS_LOG_TAB_SIZE - 1) + 1` for the chosen shift
    ///
    /// Shift 0 also yields 1; `write` tells the small code apart by `num_symbols`.
    pub method: u32,
    /// Index of the balancing bin, the one that absorbs the rounding error.
    pub omit_pos: usize,
    /// Number of symbols with a non-zero count (drives the small code).
    num_symbols: usize,
    /// The first (up to two) symbols with a non-zero count, for the small code.
    symbols: [usize; 2],
}
588
589impl ANSEncodingHistogram {
590    /// Create an empty histogram.
591    pub fn new() -> Self {
592        Self {
593            counts: Vec::new(),
594            alphabet_size: 0,
595            cost: f32::MAX,
596            method: 0,
597            omit_pos: 0,
598            num_symbols: 0,
599            symbols: [0, 0],
600        }
601    }
602
603    /// Create from a Histogram with the best normalization.
604    ///
605    /// Tries different shift values and picks the one with lowest cost.
606    pub fn from_histogram(
607        histo: &super::histogram::Histogram,
608        strategy: ANSHistogramStrategy,
609    ) -> Result<Self> {
610        if histo.total_count == 0 {
611            // Empty histogram
612            return Ok(Self {
613                counts: vec![0i32; histo.counts.len().max(1)],
614                alphabet_size: 1,
615                cost: 0.0,
616                method: 0, // Flat
617                omit_pos: 0,
618                num_symbols: 0,
619                symbols: [0, 0],
620            });
621        }
622
623        let alphabet_size = histo.alphabet_size();
624
625        // Count non-zero symbols
626        let mut num_symbols = 0;
627        let mut symbols = [0usize; 2];
628        for (i, &count) in histo.counts.iter().enumerate() {
629            if count > 0 {
630                if num_symbols < 2 {
631                    symbols[num_symbols] = i;
632                }
633                num_symbols += 1;
634            }
635        }
636
637        // Single symbol or two symbols: use small code
638        if num_symbols <= 2 {
639            let mut counts = vec![0i32; alphabet_size];
640            if num_symbols == 1 {
641                counts[symbols[0]] = ANS_TAB_SIZE as i32;
642            } else {
643                // Two symbols: proportional allocation
644                let total = histo.total_count as f64;
645                let count0 = histo.counts[symbols[0]] as f64;
646                let norm0 = ((count0 / total) * ANS_TAB_SIZE as f64).round() as i32;
647                let norm0 = norm0.clamp(1, (ANS_TAB_SIZE - 1) as i32);
648                counts[symbols[0]] = norm0;
649                counts[symbols[1]] = ANS_TAB_SIZE as i32 - norm0;
650            }
651
652            // Cost is just the header
653            let cost = if num_symbols <= 1 { 4.0 } else { 4.0 + 12.0 }; // Approximate
654
655            return Ok(Self {
656                counts,
657                alphabet_size,
658                cost,
659                method: 1, // Small code
660                omit_pos: symbols[0],
661                num_symbols,
662                symbols,
663            });
664        }
665
666        // General case: try different shift values
667        let mut best = Self::new();
668
669        let shifts: Vec<u32> = match strategy {
670            ANSHistogramStrategy::Fast => vec![0, 6, 12],
671            ANSHistogramStrategy::Approximate => (0..=ANS_LOG_TAB_SIZE).step_by(2).collect(),
672            ANSHistogramStrategy::Precise => (0..ANS_LOG_TAB_SIZE).collect(),
673        };
674
675        for shift in shifts {
676            let mut candidate = Self {
677                counts: vec![0i32; alphabet_size],
678                alphabet_size,
679                cost: f32::MAX,
680                method: shift.min(ANS_LOG_TAB_SIZE - 1) + 1,
681                omit_pos: 0,
682                num_symbols,
683                symbols,
684            };
685
686            if candidate.rebalance_histogram(histo, shift) {
687                candidate.cost = candidate.estimate_cost(histo);
688                if candidate.cost < best.cost {
689                    best = candidate;
690                }
691            }
692        }
693
694        if best.cost == f32::MAX {
695            return Err(Error::InvalidHistogram(
696                "Failed to rebalance histogram".to_string(),
697            ));
698        }
699
700        Ok(best)
701    }
702
703    /// Rebalance histogram to sum to ANS_TAB_SIZE with given shift.
704    ///
705    /// Returns true on success, false on failure.
706    ///
707    /// The decoder determines omit_pos as the first symbol with highest logcount,
708    /// then computes its frequency as 4096 - sum(others). We must match this.
709    fn rebalance_histogram(&mut self, histo: &super::histogram::Histogram, shift: u32) -> bool {
710        let total_count = histo.total_count;
711        if total_count == 0 {
712            return false;
713        }
714
715        let norm = ANS_TAB_SIZE as f64 / total_count as f64;
716
717        // First pass: normalize all counts (without precision constraints yet)
718        for (i, &count) in histo.counts.iter().enumerate().take(self.alphabet_size) {
719            if count == 0 {
720                self.counts[i] = 0;
721                continue;
722            }
723
724            let target = count as f64 * norm;
725            let mut normalized = target.round() as i32;
726            normalized = normalized.max(1);
727            normalized = normalized.min((ANS_TAB_SIZE - 1) as i32);
728            self.counts[i] = normalized;
729        }
730
731        // Find the first symbol with maximum logcount - this will be omit_pos
732        let mut max_logcount = 0u32;
733        let mut omit_pos = 0;
734        for (i, &count) in self.counts.iter().enumerate().take(self.alphabet_size) {
735            if count > 0 {
736                let logcount = floor_log2(count as u32) + 1;
737                if logcount > max_logcount {
738                    max_logcount = logcount;
739                    omit_pos = i;
740                }
741            }
742        }
743        self.omit_pos = omit_pos;
744
745        // Second pass: apply precision constraints to all symbols EXCEPT omit_pos
746        let mut running_total = 0i32;
747        for i in 0..self.alphabet_size {
748            if i == omit_pos || self.counts[i] == 0 {
749                if i != omit_pos {
750                    running_total += self.counts[i];
751                }
752                continue;
753            }
754
755            let mut normalized = self.counts[i];
756
757            // Apply precision constraints for non-omit symbols
758            if shift < ANS_LOG_TAB_SIZE && normalized > 1 {
759                let logcount = floor_log2(normalized as u32);
760                let precision = get_population_count_precision(logcount, shift);
761                let drop_bits = logcount.saturating_sub(precision);
762                let mask = (1i32 << drop_bits) - 1;
763                normalized &= !mask;
764                if normalized == 0 {
765                    normalized = 1i32 << drop_bits;
766                }
767            }
768
769            self.counts[i] = normalized;
770            running_total += normalized;
771        }
772
773        // omit_pos frequency is the remainder (this is how decoder computes it)
774        let remainder = ANS_TAB_SIZE as i32 - running_total;
775        if remainder <= 0 || remainder > ANS_TAB_SIZE as i32 {
776            return false;
777        }
778        self.counts[omit_pos] = remainder;
779
780        // Verify omit_pos is the FIRST symbol with the highest logcount.
781        // The decoder re-derives omit_pos by scanning symbols in order and picking
782        // the first one with the maximum logcount. We must ensure that after
783        // rebalancing, no EARLIER symbol has the same or higher logcount.
784        let omit_logcount = floor_log2(self.counts[omit_pos] as u32) + 1;
785        for (i, &count) in self.counts.iter().enumerate().take(self.alphabet_size) {
786            if i == omit_pos {
787                continue;
788            }
789            if count > 0 {
790                let logcount = floor_log2(count as u32) + 1;
791                if logcount > omit_logcount {
792                    // Higher logcount - decoder picks this instead
793                    return false;
794                }
795                if logcount == omit_logcount && i < omit_pos {
796                    // Same logcount but earlier position - decoder picks this instead
797                    return false;
798                }
799            }
800        }
801
802        // Verify sum
803        let sum: i32 = self.counts.iter().sum();
804        sum == ANS_TAB_SIZE as i32
805    }
806
807    /// Estimate encoding cost (header + data bits).
808    fn estimate_cost(&self, histo: &super::histogram::Histogram) -> f32 {
809        // Header cost estimate
810        let header_cost = self.estimate_header_cost();
811
812        // Data cost: entropy of original data with normalized probabilities
813        let data_cost = self.estimate_data_cost(histo);
814
815        header_cost + data_cost
816    }
817
818    /// Estimate header encoding cost.
819    fn estimate_header_cost(&self) -> f32 {
820        if self.method == 0 {
821            // Flat: 2 bits + alphabet size encoding
822            2.0 + 8.0
823        } else if self.num_symbols <= 2 {
824            // Small code
825            if self.num_symbols <= 1 {
826                3.0 + 8.0 // nsym=0: marker + symbol
827            } else {
828                3.0 + 16.0 + 12.0 // nsym=2: marker + 2 symbols + count
829            }
830        } else {
831            // General code: method encoding + alphabet + frequencies
832            let method_bits = 4.0; // Unary + suffix for method
833            let alphabet_bits = 8.0;
834            let freq_bits = self.alphabet_size as f32 * 5.0; // Rough estimate
835            method_bits + alphabet_bits + freq_bits
836        }
837    }
838
839    /// Estimate data encoding cost using normalized frequencies.
840    fn estimate_data_cost(&self, histo: &super::histogram::Histogram) -> f32 {
841        let mut cost = 0.0f32;
842
843        for (i, &count) in histo.counts.iter().enumerate() {
844            if count > 0 {
845                let normalized = self.counts.get(i).copied().unwrap_or(1).max(1);
846                let prob = normalized as f32 / ANS_TAB_SIZE as f32;
847                cost -= count as f32 * prob.log2();
848            }
849        }
850
851        cost
852    }
853
854    /// Write this histogram to a BitWriter.
855    pub fn write(&self, writer: &mut BitWriter) -> Result<()> {
856        if self.method == 0 {
857            // Flat distribution
858            writer.write(1, 0)?; // Non-small
859            writer.write(1, 1)?; // Flat
860            write_var_len_uint8(writer, (self.alphabet_size - 1) as u8)?;
861            return Ok(());
862        }
863
864        if self.num_symbols <= 2 {
865            // Small code
866            writer.write(1, 1)?; // Small tree marker
867            if self.num_symbols == 0 {
868                writer.write(1, 0)?;
869                write_var_len_uint8(writer, 0)?;
870            } else {
871                writer.write(1, (self.num_symbols - 1) as u64)?;
872                for i in 0..self.num_symbols {
873                    write_var_len_uint8(writer, self.symbols[i] as u8)?;
874                }
875                if self.num_symbols == 2 {
876                    writer.write(
877                        ANS_LOG_TAB_SIZE as usize,
878                        self.counts[self.symbols[0]] as u64,
879                    )?;
880                }
881            }
882            return Ok(());
883        }
884
885        // General code
886        self.write_general(writer)
887    }
888
889    /// Write general (non-flat, non-small) histogram.
890    ///
891    /// Matches the format expected by jxl-rs `decode_dist_complex()`.
892    fn write_general(&self, writer: &mut BitWriter) -> Result<()> {
893        writer.write(1, 0)?; // Non-small
894        writer.write(1, 0)?; // Non-flat
895
896        // Encode shift using unary + suffix (method = shift + 1)
897        // Format: len ones, then 0 (unless at max), then len suffix bits
898        let shift = (self.method - 1) as i32;
899        let shift_val = (shift + 1) as u32; // shift+1 is stored, range 1-13
900
901        // Determine unary length
902        let mut len = 0u32;
903        while len < 3 && shift_val >= (1u32 << (len + 1)) {
904            len += 1;
905        }
906
907        // Write unary prefix (len ones)
908        for _ in 0..len {
909            writer.write(1, 1)?;
910        }
911        // Write terminating 0 if len < 3
912        if len < 3 {
913            writer.write(1, 0)?;
914        }
915        // Write suffix bits
916        if len > 0 {
917            let suffix = shift_val - (1u32 << len);
918            writer.write(len as usize, suffix as u64)?;
919        }
920
921        // Encode alphabet size - 3
922        if self.alphabet_size < 3 {
923            return Err(Error::InvalidHistogram(
924                "General histogram needs at least 3 symbols".to_string(),
925            ));
926        }
927        write_var_len_uint8(writer, (self.alphabet_size - 3) as u8)?;
928
929        // Encode log-frequency values for each symbol using the fixed prefix code.
930        // The decoder determines omit_pos as the first symbol with highest logcount.
931        for i in 0..self.alphabet_size {
932            let count = self.counts[i];
933
934            // Compute logcount (0 means freq=0, 1-12 means log2(freq)+1)
935            let logcount = if count <= 0 {
936                0
937            } else {
938                floor_log2(count as u32) + 1
939            };
940
941            // Write the logcount using fixed prefix code
942            let (nbits, code) = LOGCOUNT_PREFIX_CODE[logcount as usize];
943            writer.write(nbits as usize, code as u64)?;
944        }
945
946        // Now write precision bits for each non-zero, non-omit symbol with logcount > 1.
947        // The decoder skips precision bits for omit_pos (highest logcount symbol).
948        for i in 0..self.alphabet_size {
949            if i == self.omit_pos {
950                continue;
951            }
952
953            let count = self.counts[i];
954            if count <= 0 {
955                continue;
956            }
957
958            let logcount = floor_log2(count as u32) + 1;
959            if logcount <= 1 {
960                // logcount=1 means freq=1, no precision bits needed
961                continue;
962            }
963
964            // zeros = logcount - 1 (the log2 of the frequency)
965            let zeros = (logcount - 1) as i32;
966            // bitcount = shift - (12 - zeros) / 2, clamped to [0, zeros]
967            let bitcount = (shift - ((ANS_LOG_TAB_SIZE as i32 - zeros) >> 1)).clamp(0, zeros);
968
969            if bitcount > 0 {
970                // The value stored is: freq = (1 << zeros) + (extra << (zeros - bitcount))
971                // So: extra = (freq - (1 << zeros)) >> (zeros - bitcount)
972                let base = 1i32 << zeros;
973                let extra = ((count - base) >> (zeros - bitcount)) as u32;
974                writer.write(bitcount as usize, extra as u64)?;
975            }
976        }
977
978        Ok(())
979    }
980}
981
982impl Default for ANSEncodingHistogram {
983    fn default() -> Self {
984        Self::new()
985    }
986}
987
988/// Encodes tokens using ANS.
989pub fn encode_tokens_ans(
990    tokens: &[(u32, u32)], // (context, value) pairs
991    distributions: &[AnsDistribution],
992    context_map: &[usize],
993    writer: &mut BitWriter,
994) -> Result<()> {
995    let mut encoder = AnsEncoder::new();
996
997    // Process tokens in reverse order (ANS requirement)
998    for &(context, value) in tokens.iter().rev() {
999        let dist_idx = context_map.get(context as usize).copied().unwrap_or(0);
1000        if let Some(dist) = distributions.get(dist_idx)
1001            && let Some(info) = dist.get(value as usize)
1002        {
1003            encoder.put_symbol(info);
1004        }
1005    }
1006
1007    encoder.finalize(writer)
1008}
1009
1010#[cfg(test)]
1011mod tests {
1012    use super::*;
1013    use crate::entropy_coding::histogram::Histogram;
1014
1015    #[test]
1016    fn test_ans_encoding_histogram_single_symbol() {
1017        let h = Histogram::from_counts(&[100, 0, 0, 0]);
1018        let encoded = ANSEncodingHistogram::from_histogram(&h, ANSHistogramStrategy::Fast).unwrap();
1019
1020        assert_eq!(encoded.num_symbols, 1);
1021        assert_eq!(encoded.method, 1); // Small code
1022        assert_eq!(encoded.counts[0], ANS_TAB_SIZE as i32);
1023        assert!(encoded.cost < 100.0); // Header only
1024    }
1025
1026    #[test]
1027    fn test_ans_encoding_histogram_two_symbols() {
1028        let h = Histogram::from_counts(&[100, 100, 0, 0]);
1029        let encoded = ANSEncodingHistogram::from_histogram(&h, ANSHistogramStrategy::Fast).unwrap();
1030
1031        assert_eq!(encoded.num_symbols, 2);
1032        assert_eq!(encoded.method, 1); // Small code
1033        // Should split roughly 50/50
1034        let sum: i32 = encoded.counts.iter().sum();
1035        assert_eq!(sum, ANS_TAB_SIZE as i32);
1036        assert!(encoded.counts[0] > 0);
1037        assert!(encoded.counts[1] > 0);
1038    }
1039
1040    #[test]
1041    fn test_ans_encoding_histogram_general() {
1042        let h = Histogram::from_counts(&[100, 50, 25, 10, 5, 3, 2, 1]);
1043        let encoded = ANSEncodingHistogram::from_histogram(&h, ANSHistogramStrategy::Fast).unwrap();
1044
1045        // Should use general code (more than 2 symbols)
1046        assert!(encoded.method >= 2 || encoded.method == 0);
1047
1048        // Sum should be exactly ANS_TAB_SIZE
1049        let sum: i32 = encoded.counts.iter().sum();
1050        assert_eq!(sum, ANS_TAB_SIZE as i32);
1051
1052        // All non-zero original counts should have non-zero normalized counts
1053        for (i, &orig) in h.counts.iter().enumerate() {
1054            if orig > 0 {
1055                assert!(
1056                    encoded.counts.get(i).copied().unwrap_or(0) > 0,
1057                    "Symbol {} had count {} but normalized to 0",
1058                    i,
1059                    orig
1060                );
1061            }
1062        }
1063    }
1064
1065    #[test]
1066    fn test_ans_encoding_histogram_empty() {
1067        let h = Histogram::new();
1068        let encoded = ANSEncodingHistogram::from_histogram(&h, ANSHistogramStrategy::Fast).unwrap();
1069
1070        assert_eq!(encoded.cost, 0.0);
1071        assert_eq!(encoded.method, 0); // Flat
1072    }
1073
1074    #[test]
1075    fn test_get_population_count_precision() {
1076        // logcount=0, any shift: precision=0
1077        assert_eq!(get_population_count_precision(0, 12), 0);
1078
1079        // logcount=12, shift=12: min(12, 12-(0/2)) = 12
1080        assert_eq!(get_population_count_precision(12, 12), 12);
1081
1082        // logcount=6, shift=6: min(6, 6-3) = 3
1083        assert_eq!(get_population_count_precision(6, 6), 3);
1084
1085        // logcount=1, shift=0: min(1, 0-(11/2)) = min(1, -5) = 0 (clamped)
1086        assert_eq!(get_population_count_precision(1, 0), 0);
1087    }
1088
1089    #[test]
1090    fn test_ans_encoding_histogram_write() {
1091        let h = Histogram::from_counts(&[100, 0, 0, 0]);
1092        let encoded = ANSEncodingHistogram::from_histogram(&h, ANSHistogramStrategy::Fast).unwrap();
1093
1094        let mut writer = BitWriter::new();
1095        encoded.write(&mut writer).unwrap();
1096
1097        let bytes = writer.finish_with_padding();
1098        assert!(!bytes.is_empty());
1099    }
1100
1101    #[test]
1102    fn test_flat_distribution() {
1103        let dist = AnsDistribution::flat(16).unwrap();
1104        assert_eq!(dist.alphabet_size(), 16);
1105
1106        // All frequencies should be 256 (4096 / 16)
1107        for sym in &dist.symbols {
1108            assert_eq!(sym.freq, 256);
1109        }
1110    }
1111
1112    #[test]
1113    fn test_from_frequencies() {
1114        let freqs = vec![100, 200, 300, 400];
1115        let dist = AnsDistribution::from_frequencies(&freqs).unwrap();
1116        assert_eq!(dist.alphabet_size(), 4);
1117
1118        // Total should be ANS_TAB_SIZE
1119        let total: u32 = dist.symbols.iter().map(|s| s.freq as u32).sum();
1120        assert_eq!(total, ANS_TAB_SIZE);
1121    }
1122
1123    #[test]
1124    fn test_ans_encoder_basic() {
1125        let dist = AnsDistribution::flat(4).unwrap();
1126        let mut encoder = AnsEncoder::new();
1127
1128        // Encode a few symbols
1129        encoder.put_symbol(&dist.symbols[0]);
1130        encoder.put_symbol(&dist.symbols[1]);
1131        encoder.put_symbol(&dist.symbols[2]);
1132
1133        // State should have changed from initial
1134        assert_ne!(encoder.state(), ANS_SIGNATURE << 16);
1135    }
1136
1137    #[test]
1138    fn test_reverse_map() {
1139        let dist = AnsDistribution::flat(4).unwrap();
1140
1141        // Each symbol should have freq entries in reverse_map
1142        for sym in &dist.symbols {
1143            assert_eq!(sym.reverse_map.len(), sym.freq as usize);
1144        }
1145
1146        // All positions 0..4096 should be covered exactly once
1147        let mut covered = vec![false; ANS_TAB_SIZE as usize];
1148        for sym in &dist.symbols {
1149            for &pos in &sym.reverse_map {
1150                assert!(!covered[pos as usize], "position {} covered twice", pos);
1151                covered[pos as usize] = true;
1152            }
1153        }
1154        assert!(covered.iter().all(|&c| c), "not all positions covered");
1155    }
1156
1157    #[test]
1158    fn test_write_distribution() {
1159        let dist = AnsDistribution::flat(16).unwrap();
1160        let mut writer = BitWriter::new();
1161        dist.write(&mut writer).unwrap();
1162
1163        let bytes = writer.finish_with_padding();
1164        // Should produce some output
1165        assert!(!bytes.is_empty());
1166    }
1167
1168    #[test]
1169    fn test_ans_roundtrip_manual() {
1170        // Create a simple flat distribution
1171        let dist = AnsDistribution::flat(2).unwrap();
1172
1173        println!("Distribution: {} symbols", dist.alphabet_size());
1174        for (i, sym) in dist.symbols.iter().enumerate() {
1175            println!("  Symbol {}: freq={}", i, sym.freq);
1176        }
1177
1178        // Encode symbol 0
1179        let mut encoder = AnsEncoder::new();
1180        let initial_state = encoder.state();
1181        println!("\nInitial state: 0x{:08x}", initial_state);
1182        assert_eq!(initial_state, 0x130000, "Initial state should be 0x130000");
1183
1184        let info = &dist.symbols[0];
1185        encoder.put_symbol(info);
1186        let encoded_state = encoder.state();
1187        println!("After encoding symbol 0: state=0x{:08x}", encoded_state);
1188
1189        // Now manually decode to verify
1190        let idx = encoded_state & 0xFFF;
1191        println!("Decode: idx = {}", idx);
1192
1193        // For flat distribution of 2, each has freq 2048
1194        // Symbol 0: cumul=0, freq=2048 -> positions [0, 2048)
1195        // Symbol 1: cumul=2048, freq=2048 -> positions [2048, 4096)
1196        let decoded_symbol = if idx < 2048 { 0 } else { 1 };
1197        let offset_in_symbol = if idx < 2048 { idx } else { idx - 2048 };
1198        let freq = 2048u32;
1199
1200        println!("Decoded symbol: {}", decoded_symbol);
1201        println!("Offset in symbol: {}", offset_in_symbol);
1202
1203        // The decoder does: next_state = (state >> 12) * freq + offset
1204        let quotient = encoded_state >> 12;
1205        let next_state = quotient * freq + offset_in_symbol;
1206        println!(
1207            "next_state = {} * {} + {} = 0x{:08x}",
1208            quotient, freq, offset_in_symbol, next_state
1209        );
1210
1211        // The next_state should be the initial state (0x130000)
1212        assert_eq!(next_state, 0x130000, "Decoded state should be 0x130000");
1213        assert_eq!(decoded_symbol, 0, "Decoded symbol should be 0");
1214    }
1215
1216    #[test]
1217    fn test_ans_roundtrip_multiple_symbols() {
1218        use crate::bit_writer::BitWriter;
1219        use crate::entropy_coding::ans_decode::{AnsHistogram, AnsReader, BitReader};
1220
1221        // Test encoding multiple symbols and verify they can be decoded
1222        // using the jxl-rs compatible decoder (alias table method)
1223
1224        // Create a flat distribution with 4 symbols (each freq = 1024)
1225        let counts = [1024i32, 1024, 1024, 1024];
1226        let dist = AnsDistribution::from_normalized_counts(&counts).unwrap();
1227
1228        let symbols_to_encode: Vec<usize> = vec![0, 1, 2, 3, 0, 1];
1229        println!(
1230            "Encoding {} symbols: {:?}",
1231            symbols_to_encode.len(),
1232            symbols_to_encode
1233        );
1234
1235        // Encode in reverse order (as ANS requires)
1236        let mut encoder = AnsEncoder::new();
1237        for &sym in symbols_to_encode.iter().rev() {
1238            encoder.put_symbol(&dist.symbols[sym]);
1239        }
1240
1241        println!("Final state after encoding: 0x{:08x}", encoder.state());
1242
1243        // Finalize encoder to bitstream
1244        let mut writer = BitWriter::new();
1245        encoder.finalize(&mut writer).unwrap();
1246        let encoded_bytes = writer.finish_with_padding();
1247        println!("Encoded bytes: {:02x?}", encoded_bytes);
1248
1249        // Build decoder histogram by writing and reading back
1250        let ans_histo = ANSEncodingHistogram::from_histogram(
1251            &Histogram::from_counts(&counts),
1252            ANSHistogramStrategy::Precise,
1253        )
1254        .unwrap();
1255        let mut hist_writer = BitWriter::new();
1256        ans_histo.write(&mut hist_writer).unwrap();
1257        let hist_bytes = hist_writer.finish_with_padding();
1258
1259        let mut hist_br = BitReader::new(&hist_bytes);
1260        let decoded_hist = AnsHistogram::decode(&mut hist_br, 6).unwrap();
1261
1262        println!(
1263            "Decoded histogram frequencies: {:?}",
1264            &decoded_hist.frequencies[..4]
1265        );
1266
1267        // Decode using jxl-rs compatible decoder
1268        let mut br = BitReader::new(&encoded_bytes);
1269        let mut ans_reader = AnsReader::init(&mut br).unwrap();
1270
1271        println!("Decoding:");
1272        let mut decoded = Vec::new();
1273        for i in 0..symbols_to_encode.len() {
1274            let symbol = decoded_hist.read(&mut br, &mut ans_reader.0) as usize;
1275            println!(
1276                "  step {}: symbol={}, state=0x{:08x}",
1277                i, symbol, ans_reader.0
1278            );
1279            decoded.push(symbol);
1280        }
1281
1282        println!("Original: {:?}", symbols_to_encode);
1283        println!("Decoded:  {:?}", decoded);
1284        println!("Final state: 0x{:08x}", ans_reader.0);
1285
1286        assert_eq!(
1287            decoded, symbols_to_encode,
1288            "Decoded symbols should match original"
1289        );
1290        assert!(
1291            ans_reader.check_final_state().is_ok(),
1292            "Final state should be 0x130000, got 0x{:08x}",
1293            ans_reader.0
1294        );
1295    }
1296
1297    #[test]
1298    fn test_ans_histogram_write_decode_roundtrip() {
1299        use crate::bit_writer::BitWriter;
1300        use crate::entropy_coding::histogram::Histogram;
1301
1302        // Create a histogram with several symbols
1303        let histo = Histogram::from_counts(&[100, 50, 25, 10]);
1304
1305        let encoded =
1306            ANSEncodingHistogram::from_histogram(&histo, ANSHistogramStrategy::Precise).unwrap();
1307
1308        println!("Histogram: {:?}", histo.counts);
1309        println!("Encoded counts: {:?}", encoded.counts);
1310        println!(
1311            "Method: {}, alphabet_size: {}, omit_pos: {}",
1312            encoded.method, encoded.alphabet_size, encoded.omit_pos
1313        );
1314
1315        // Verify sum is 4096
1316        let sum: i32 = encoded.counts.iter().sum();
1317        assert_eq!(sum, ANS_TAB_SIZE as i32, "Sum should be 4096");
1318
1319        // Write to bitstream
1320        let mut writer = BitWriter::new();
1321        encoded.write(&mut writer).unwrap();
1322        let bytes = writer.finish_with_padding();
1323
1324        println!("Encoded histogram: {} bytes", bytes.len());
1325        println!("Bytes: {:02x?}", bytes);
1326    }
1327}