Skip to main content

jxl_encoder/entropy_coding/
ans.rs

1// Copyright (c) Imazen LLC and the JPEG XL Project Authors.
2// Algorithms and constants derived from libjxl (BSD-3-Clause).
3// Licensed under AGPL-3.0-or-later. Commercial licenses at https://www.imazen.io/pricing
4
5//! ANS (Asymmetric Numeral Systems) encoder.
6//!
7//! ANS is an entropy coding method used in JPEG XL for efficient symbol
8//! compression. This module implements the rANS (range ANS) variant.
9
10use crate::bit_writer::BitWriter;
11use crate::error::{Error, Result};
12
/// Log2 of the ANS table size: normalized frequencies are expressed out of
/// 2^12 = 4096.
pub const ANS_LOG_TAB_SIZE: u32 = 12;
/// Number of ANS table slots (4096). Every normalized distribution built in
/// this module must sum to exactly this value.
pub const ANS_TAB_SIZE: u32 = 1 << ANS_LOG_TAB_SIZE;
/// Mask keeping the low `ANS_LOG_TAB_SIZE` bits, i.e. a slot index in the table.
pub const ANS_TAB_MASK: u32 = ANS_TAB_SIZE - 1;

/// Maximum alphabet size for ANS.
pub const ANS_MAX_ALPHABET_SIZE: usize = 256;

/// Initial state marker: a fresh `AnsEncoder` starts in state
/// `ANS_SIGNATURE << 16`.
pub const ANS_SIGNATURE: u32 = 0x13;

/// Logcount-alphabet symbol that introduces a run-length repeat (entry 13,
/// i.e. `ANS_LOG_TAB_SIZE + 1`, of `LOGCOUNT_PREFIX_CODE`).
const RLE_MARKER_SYM: u8 = 13;

/// Static prefix code for the per-symbol "logcount" values (0-13) written in
/// an ANS histogram header.
///
/// Entry `i` is `(nbits, code_lsb)`: the code for logcount `i`, given as an
/// `nbits`-bit value in LSB-first order (the value handed to the bit writer).
/// Logcount `L >= 1` stands for a count in `[2^(L-1), 2^L - 1]`; logcount 0
/// means the symbol is absent (count 0).
/// This is the inverse of the decoder's lookup table in jxl-rs.
const LOGCOUNT_PREFIX_CODE: [(u8, u8); 14] = [
    (5, 0b10001),   // 0: logcount=0, freq=0 (symbol absent)
    (4, 0b1011),    // 1: logcount=1, freq=1
    (4, 0b1111),    // 2: logcount=2, freq in [2,3]
    (4, 0b0011),    // 3: logcount=3, freq in [4,7]
    (4, 0b1001),    // 4: logcount=4, freq in [8,15]
    (4, 0b0111),    // 5: logcount=5, freq in [16,31]
    (3, 0b100),     // 6: logcount=6, freq in [32,63]
    (3, 0b010),     // 7: logcount=7, freq in [64,127]
    (3, 0b101),     // 8: logcount=8, freq in [128,255]
    (3, 0b110),     // 9: logcount=9, freq in [256,511]
    (3, 0b000),     // 10: logcount=10, freq in [512,1023]
    (6, 0b100001),  // 11: logcount=11, freq in [1024,2047]
    (7, 0b0000001), // 12: logcount=12, freq in [2048,4095]
    (7, 0b1000001), // 13: RLE marker (RLE_MARKER_SYM)
];
46
47/// Build sorted table of all representable count values for a given shift.
48/// Matches libjxl's AllowedCounts precomputation (enc_ans.cc:581-615).
49/// Returns counts in DECREASING order (index 0 = highest count).
50fn build_allowed_counts(shift: u32) -> Vec<i32> {
51    let mut counts = Vec::with_capacity(256);
52    // Count = 1 is always representable (logcount=1, no precision bits)
53    counts.push(1i32);
54    for bits in 1..ANS_LOG_TAB_SIZE {
55        let precision = get_population_count_precision(bits, shift);
56        let drop_bits = bits.saturating_sub(precision);
57        let num_mantissa = 1u32 << precision;
58        for mantissa in 0..num_mantissa {
59            let count = (1i32 << bits) | ((mantissa as i32) << drop_bits);
60            if count > 0 && count < ANS_TAB_SIZE as i32 {
61                counts.push(count);
62            }
63        }
64    }
65    counts.sort_unstable();
66    counts.dedup();
67    counts.reverse(); // Decreasing order: index 0 = highest
68    counts
69}
70
71/// Precomputed allowed counts tables for all shift values 0..=ANS_LOG_TAB_SIZE.
72/// These tables are deterministic (depend only on shift value) and can be
73/// computed once and reused across all histogram normalization calls.
74pub struct AllowedCountsCache {
75    // 13 entries: shifts 0 through 12 inclusive (ANS_LOG_TAB_SIZE = 12).
76    tables: [Vec<i32>; ANS_LOG_TAB_SIZE as usize + 1],
77}
78
79impl Default for AllowedCountsCache {
80    fn default() -> Self {
81        Self::new()
82    }
83}
84
85impl AllowedCountsCache {
86    /// Build all 13 allowed counts tables (one per shift value, 0..=12).
87    pub fn new() -> Self {
88        Self {
89            tables: core::array::from_fn(|shift| build_allowed_counts(shift as u32)),
90        }
91    }
92
93    /// Get the allowed counts table for a given shift.
94    #[inline]
95    pub fn get(&self, shift: u32) -> &[i32] {
96        &self.tables[shift as usize]
97    }
98}
99
/// Returns the index of the largest allowed count that does not exceed `target`.
///
/// `allowed` must be sorted in decreasing order. The result snaps DOWN so the
/// remaining mass never goes negative (matches libjxl's mask-off behavior).
/// If every entry exceeds `target` (for the cached tables: `target < 1`), the
/// last index, i.e. the smallest count, is returned.
fn find_allowed_leq(allowed: &[i32], target: i32) -> usize {
    // In a decreasing table, "count > target" holds for a prefix, so the
    // partition point is the first entry that is <= target.
    let first_fit = allowed.partition_point(|&count| count > target);
    if first_fit < allowed.len() {
        first_fit
    } else {
        allowed.len() - 1
    }
}
122
123/// Estimate data cost of encoding `histo` using ANS with normalized `counts`.
124/// Matches libjxl's `EstimateDataBits` (enc_ans.cc:362-370).
125/// `cost = total * ANS_LOG_TAB_SIZE - sum(actual_count * log2(norm_count))`
126fn estimate_data_bits_normalized(
127    histo_counts: &[i32],
128    norm_counts: &[i32],
129    total_count: usize,
130    alphabet_size: usize,
131) -> f64 {
132    let mut sum = 0.0f64;
133    for (actual, norm) in histo_counts
134        .iter()
135        .zip(norm_counts.iter())
136        .take(alphabet_size)
137    {
138        if *actual > 0 && *norm > 0 {
139            sum += *actual as f64 * jxl_simd::fast_log2f(*norm as f32) as f64;
140        }
141    }
142    total_count as f64 * ANS_LOG_TAB_SIZE as f64 - sum
143}
144
/// Fixed-point precision, in bits, of the reciprocal kept in
/// `AnsEncSymbolInfo::ifreq`. The encoder multiplies by it and shifts instead
/// of dividing.
const RECIPROCAL_PRECISION: u32 = 44;

/// Everything the ANS encoder needs to code one symbol.
#[derive(Debug, Clone)]
pub struct AnsEncSymbolInfo {
    /// Normalized frequency out of `ANS_TAB_SIZE`; 0 for an absent symbol.
    pub freq: u16,
    /// `ceil(2^44 / freq)`, or 0 when `freq` is 0.
    pub ifreq: u64,
    /// For each remainder `r` in `0..freq`, the table slot that decodes back
    /// to this symbol. Empty until the reverse maps are built.
    pub reverse_map: Vec<u16>,
}

impl AnsEncSymbolInfo {
    /// Builds the info for a symbol of frequency `freq`, with an empty reverse map.
    pub fn new(freq: u16) -> Self {
        let ifreq = match u64::from(freq) {
            0 => 0,
            divisor => (1u64 << RECIPROCAL_PRECISION).div_ceil(divisor),
        };
        Self {
            freq,
            ifreq,
            reverse_map: Vec::new(),
        }
    }
}
175
176/// ANS encoder state.
177pub struct AnsEncoder {
178    /// ANS state (normalized to range).
179    state: u32,
180    /// Accumulated output bits (stored reversed).
181    bits: Vec<(u32, u8)>, // (bits, nbits)
182}
183
184impl AnsEncoder {
185    /// Creates a new ANS encoder.
186    pub fn new() -> Self {
187        Self {
188            state: ANS_SIGNATURE << 16,
189            bits: Vec::new(),
190        }
191    }
192
193    /// Creates a new ANS encoder with pre-allocated capacity for `num_tokens` tokens.
194    pub fn with_capacity(num_tokens: usize) -> Self {
195        Self {
196            state: ANS_SIGNATURE << 16,
197            bits: Vec::with_capacity(num_tokens * 2), // ~2 entries per token (symbol + extra bits)
198        }
199    }
200
201    /// Encodes a single symbol using precomputed symbol info.
202    ///
203    /// Returns the bits that should be output (if any).
204    #[inline]
205    pub fn put_symbol(&mut self, info: &AnsEncSymbolInfo) {
206        let freq = info.freq as u32;
207
208        // Renormalization: if state is too large, emit 16 bits
209        if (self.state >> (32 - ANS_LOG_TAB_SIZE)) >= freq {
210            self.bits.push((self.state & 0xFFFF, 16));
211            self.state >>= 16;
212        }
213
214        // State update using multiplication-by-reciprocal
215        // v = state / freq (approximately)
216        let v = ((self.state as u64 * info.ifreq) >> RECIPROCAL_PRECISION) as u32;
217        let remainder = self.state - v * freq;
218
219        // Look up offset in reverse map
220        let offset = info.reverse_map[remainder as usize] as u32;
221
222        // Update state
223        self.state = (v << ANS_LOG_TAB_SIZE) + offset;
224    }
225
226    /// Pushes extra bits into the encoder's output buffer.
227    ///
228    /// Used for HybridUint extra bits that are interleaved with ANS symbols.
229    /// These bits are stored in the same reversed buffer and will be emitted
230    /// in proper order during finalize().
231    #[inline]
232    pub fn push_bits(&mut self, bits: u32, nbits: u8) {
233        if nbits > 0 {
234            self.bits.push((bits, nbits));
235        }
236    }
237
238    /// Finalizes encoding and writes to a BitWriter.
239    ///
240    /// Writes the final state followed by all accumulated bits in reverse order.
241    pub fn finalize(self, writer: &mut BitWriter) -> Result<()> {
242        // Debug: show final state
243        #[cfg(feature = "debug-tokens")]
244        eprintln!(
245            "ANS finalize: state=0x{:08x}, {} bit chunks",
246            self.state,
247            self.bits.len()
248        );
249
250        // Write final state (32 bits)
251        writer.write(32, self.state as u64)?;
252
253        // Write accumulated bits in reverse order
254        for &(bits, nbits) in self.bits.iter().rev() {
255            writer.write(nbits as usize, bits as u64)?;
256        }
257
258        Ok(())
259    }
260
261    /// Returns the current state.
262    pub fn state(&self) -> u32 {
263        self.state
264    }
265}
266
267impl Default for AnsEncoder {
268    fn default() -> Self {
269        Self::new()
270    }
271}
272
273/// A complete ANS distribution with encoding info for all symbols.
274#[derive(Debug, Clone)]
275pub struct AnsDistribution {
276    /// Symbol encoding information.
277    pub symbols: Vec<AnsEncSymbolInfo>,
278    /// Log2 of distribution size (typically 12).
279    pub log_alpha_size: u32,
280    /// Total of normalized frequencies (should be ANS_TAB_SIZE).
281    pub total: u32,
282}
283
284impl AnsDistribution {
285    /// Creates a distribution from raw frequencies.
286    ///
287    /// Normalizes frequencies to sum to ANS_TAB_SIZE.
288    pub fn from_frequencies(freqs: &[u32]) -> Result<Self> {
289        if freqs.is_empty() {
290            return Err(Error::InvalidHistogram("empty distribution".to_string()));
291        }
292
293        let total_count: u64 = freqs.iter().map(|&f| f as u64).sum();
294        if total_count == 0 {
295            return Err(Error::InvalidHistogram("all zero frequencies".to_string()));
296        }
297
298        // Normalize frequencies to sum to ANS_TAB_SIZE
299        let mut normalized: Vec<u16> = Vec::with_capacity(freqs.len());
300        let mut running_total: u32 = 0;
301
302        for &freq in freqs.iter() {
303            let normalized_freq = if freq == 0 {
304                0
305            } else {
306                // Scale to ANS_TAB_SIZE, ensuring at least 1 for non-zero
307                ((freq as u64 * ANS_TAB_SIZE as u64) / total_count).max(1) as u16
308            };
309            normalized.push(normalized_freq);
310            running_total += normalized_freq as u32;
311        }
312
313        // Adjust to exactly sum to ANS_TAB_SIZE
314        let diff = running_total as i32 - ANS_TAB_SIZE as i32;
315        if diff != 0 {
316            // Find largest frequency and adjust it
317            if let Some((max_idx, _)) = normalized
318                .iter()
319                .enumerate()
320                .filter(|&(_, &f)| f > 0)
321                .max_by_key(|&(_, &f)| f)
322            {
323                let new_val = (normalized[max_idx] as i32 - diff).max(1) as u16;
324                normalized[max_idx] = new_val;
325            }
326        }
327
328        // Build symbol info with reverse maps
329        let mut symbols: Vec<AnsEncSymbolInfo> = normalized
330            .iter()
331            .map(|&f| AnsEncSymbolInfo::new(f))
332            .collect();
333
334        // Build reverse map (alias table) using default log_alpha_size for this alphabet
335        let log_alpha_size = Self::default_log_alpha_size(symbols.len());
336        Self::build_reverse_maps(&mut symbols, log_alpha_size)?;
337
338        Ok(Self {
339            symbols,
340            log_alpha_size: ANS_LOG_TAB_SIZE,
341            total: ANS_TAB_SIZE,
342        })
343    }
344
345    /// Creates a flat (uniform) distribution.
346    pub fn flat(alphabet_size: usize) -> Result<Self> {
347        if alphabet_size == 0 || alphabet_size > ANS_TAB_SIZE as usize {
348            return Err(Error::InvalidHistogram(format!(
349                "invalid alphabet size: {}",
350                alphabet_size
351            )));
352        }
353
354        let base_freq = ANS_TAB_SIZE as usize / alphabet_size;
355        let remainder = ANS_TAB_SIZE as usize % alphabet_size;
356
357        let mut freqs = vec![base_freq as u32; alphabet_size];
358        for freq in freqs.iter_mut().take(remainder) {
359            *freq += 1;
360        }
361
362        Self::from_frequencies(&freqs)
363    }
364
365    /// Creates a distribution from pre-normalized counts.
366    ///
367    /// The counts must already sum to ANS_TAB_SIZE (4096). This is used
368    /// when building distributions from ANSEncodingHistogram which has
369    /// already done the normalization.
370    pub fn from_normalized_counts(counts: &[i32]) -> Result<Self> {
371        let log_alpha_size = Self::default_log_alpha_size(counts.len());
372        Self::from_normalized_counts_with_log_alpha(counts, log_alpha_size)
373    }
374
375    /// Creates a distribution from pre-normalized counts with explicit log_alpha_size.
376    ///
377    /// Use this when multiple distributions share a single header (e.g., multi-histogram
378    /// ANS). The `log_alpha_size` must match the value written to the bitstream header,
379    /// NOT the per-distribution default. The decoder reads one global log_alpha_size
380    /// and uses it for all distributions in the group.
381    pub fn from_normalized_counts_with_log_alpha(
382        counts: &[i32],
383        log_alpha_size: usize,
384    ) -> Result<Self> {
385        if counts.is_empty() {
386            return Err(Error::InvalidHistogram("empty distribution".to_string()));
387        }
388
389        // Verify sum
390        let total: i32 = counts.iter().sum();
391        if total != ANS_TAB_SIZE as i32 {
392            return Err(Error::InvalidHistogram(format!(
393                "normalized counts sum to {} instead of {}",
394                total, ANS_TAB_SIZE
395            )));
396        }
397
398        // Build symbol info
399        let mut symbols: Vec<AnsEncSymbolInfo> = counts
400            .iter()
401            .map(|&c| AnsEncSymbolInfo::new(c.max(0) as u16))
402            .collect();
403
404        // Build reverse maps with the caller-specified log_alpha_size
405        Self::build_reverse_maps(&mut symbols, log_alpha_size)?;
406
407        Ok(Self {
408            symbols,
409            log_alpha_size: ANS_LOG_TAB_SIZE,
410            total: ANS_TAB_SIZE,
411        })
412    }
413
414    /// Computes the default log_alpha_size for a given alphabet size.
415    ///
416    /// This is the value that would be written to the bitstream header for a
417    /// standalone distribution. For multi-histogram contexts, use the global
418    /// log_alpha_size from the header instead.
419    fn default_log_alpha_size(alphabet_size: usize) -> usize {
420        use super::encode_ans::ANS_LOG_ALPHA_SIZE;
421        if alphabet_size <= (1 << ANS_LOG_ALPHA_SIZE) {
422            ANS_LOG_ALPHA_SIZE
423        } else {
424            let min_bits = if alphabet_size <= 1 {
425                5
426            } else {
427                (alphabet_size - 1).ilog2() as usize + 1
428            };
429            min_bits.clamp(5, 8)
430        }
431    }
432
433    /// Builds reverse maps for all symbols using the alias table method.
434    ///
435    /// The decoder uses an alias table with buckets. Each idx in [0, 4096) maps to some
436    /// symbol and an offset within that symbol's range. The encoder needs to know,
437    /// for each symbol s and remainder r in [0, freq[s]), what idx to output.
438    ///
439    /// This exactly mirrors the decoder's build_alias_map and read methods.
440    ///
441    /// `log_alpha_size` MUST match the value written to the bitstream header, since
442    /// the decoder uses it to split 12-bit indices into (bucket, position) pairs.
443    /// When multiple distributions share a header, they all use the same global
444    /// log_alpha_size — passing a per-distribution value causes alias table mismatch.
445    fn build_reverse_maps(symbols: &mut [AnsEncSymbolInfo], log_alpha_size: usize) -> Result<()> {
446        let alphabet_size = symbols.len();
447        if alphabet_size == 0 {
448            return Ok(());
449        }
450
451        // Verify frequencies sum to ANS_TAB_SIZE
452        let total: u32 = symbols.iter().map(|s| s.freq as u32).sum();
453        if total != ANS_TAB_SIZE {
454            return Err(Error::InvalidHistogram(format!(
455                "frequencies sum to {} instead of {}",
456                total, ANS_TAB_SIZE
457            )));
458        }
459
460        // Special case: single-symbol distribution
461        // jxl-rs uses a simplified alias table where offset = idx for all positions.
462        // This means reverse_map[r] = r (identity mapping) for the single symbol.
463        if let Some(single_sym_idx) = symbols.iter().position(|s| s.freq == ANS_TAB_SIZE as u16) {
464            // Clear all reverse maps
465            for sym in symbols.iter_mut() {
466                sym.reverse_map.clear();
467            }
468            // Set identity mapping for the single symbol
469            let map = &mut symbols[single_sym_idx].reverse_map;
470            map.resize(ANS_TAB_SIZE as usize, 0);
471            for (i, v) in map.iter_mut().enumerate() {
472                *v = i as u16;
473            }
474            return Ok(());
475        }
476
477        let table_size = 1usize << log_alpha_size;
478        let log_bucket_size = ANS_LOG_TAB_SIZE as usize - log_alpha_size;
479        let bucket_size = 1u16 << log_bucket_size;
480
481        // Working bucket structure matching jxl-rs
482        #[derive(Clone)]
483        #[allow(dead_code)]
484        struct WorkingBucket {
485            dist: u16,         // Frequency of primary symbol
486            alias_symbol: u16, // Alias symbol (used when pos >= cutoff)
487            alias_offset: u16, // Offset for alias symbol
488            alias_cutoff: u16, // Positions [0, cutoff) use primary, [cutoff, bucket_size) use alias
489        }
490
491        let mut buckets: Vec<WorkingBucket> = (0..table_size)
492            .map(|i| {
493                let dist = if i < alphabet_size {
494                    symbols[i].freq
495                } else {
496                    0
497                };
498                WorkingBucket {
499                    dist,
500                    alias_symbol: if i < alphabet_size { i as u16 } else { 0 },
501                    alias_offset: 0,
502                    alias_cutoff: dist,
503                }
504            })
505            .collect();
506
507        // Separate into underfull and overfull
508        let mut underfull: Vec<usize> = Vec::with_capacity(table_size);
509        let mut overfull: Vec<usize> = Vec::with_capacity(table_size);
510        for (i, bucket) in buckets.iter().enumerate() {
511            if bucket.alias_cutoff < bucket_size {
512                underfull.push(i);
513            } else if bucket.alias_cutoff > bucket_size {
514                overfull.push(i);
515            }
516        }
517
518        // Alias redistribution - exactly matching jxl-rs
519        while let (Some(o), Some(u)) = (overfull.pop(), underfull.pop()) {
520            let by = bucket_size - buckets[u].alias_cutoff;
521            buckets[o].alias_cutoff -= by;
522            buckets[u].alias_symbol = o as u16;
523            buckets[u].alias_offset = buckets[o].alias_cutoff;
524
525            match buckets[o].alias_cutoff.cmp(&bucket_size) {
526                std::cmp::Ordering::Less => underfull.push(o),
527                std::cmp::Ordering::Greater => overfull.push(o),
528                std::cmp::Ordering::Equal => {}
529            }
530        }
531
532        // Pre-allocate reverse maps with exact sizes (offset is 0..freq-1)
533        for sym in symbols.iter_mut() {
534            sym.reverse_map.clear();
535            sym.reverse_map.resize(sym.freq as usize, 0);
536        }
537
538        // For each idx in [0, 4096), simulate the decoder to find which symbol it decodes to
539        // and what offset within that symbol's range. Write directly into reverse_map[offset].
540        for idx in 0..ANS_TAB_SIZE {
541            let bucket_idx = (idx >> log_bucket_size) as usize;
542            let pos = (idx as u16) & (bucket_size - 1);
543
544            let bucket = &buckets[bucket_idx.min(table_size - 1)];
545            let alias_cutoff = bucket.alias_cutoff;
546
547            let (symbol, offset) = if pos < alias_cutoff {
548                (bucket_idx, pos)
549            } else {
550                let alias_sym = bucket.alias_symbol as usize;
551                let offset = bucket.alias_offset - alias_cutoff + pos;
552                (alias_sym, offset)
553            };
554
555            if symbol < alphabet_size {
556                symbols[symbol].reverse_map[offset as usize] = idx as u16;
557            }
558        }
559
560        Ok(())
561    }
562
563    /// Returns the number of symbols in this distribution.
564    pub fn alphabet_size(&self) -> usize {
565        self.symbols.len()
566    }
567
568    /// Gets the encoding info for a symbol.
569    pub fn get(&self, symbol: usize) -> Option<&AnsEncSymbolInfo> {
570        self.symbols.get(symbol)
571    }
572
573    /// Writes this distribution to a BitWriter.
574    pub fn write(&self, writer: &mut BitWriter) -> Result<()> {
575        // Check if this is a flat distribution
576        let is_flat = self.is_flat();
577
578        writer.write(1, 0)?; // Non-small tree marker
579        writer.write(1, u64::from(is_flat))?;
580
581        if is_flat {
582            // Flat distribution: just encode alphabet size
583            write_var_len_uint8(writer, (self.alphabet_size() - 1) as u8)?;
584        } else {
585            // General distribution encoding
586            // For now, use the simplest encoding (shift = 0, meaning power-of-2 counts)
587            self.write_general(writer)?;
588        }
589
590        Ok(())
591    }
592
593    /// Checks if this is a flat (uniform) distribution.
594    fn is_flat(&self) -> bool {
595        let first_freq = self.symbols.first().map(|s| s.freq).unwrap_or(0);
596        if first_freq == 0 {
597            return false;
598        }
599        self.symbols
600            .iter()
601            .all(|s| s.freq == first_freq || s.freq == first_freq - 1)
602    }
603
604    /// Writes a general (non-flat) distribution.
605    fn write_general(&self, writer: &mut BitWriter) -> Result<()> {
606        // Encode shift (we use shift=12 for maximum precision)
607        let method: u64 = 13; // shift + 1
608        let upper_bound_log = 4; // floor_log2(13)
609        let log = floor_log2(method as u32);
610
611        // Write unary prefix
612        writer.write(log as usize, (1u64 << log) - 1)?;
613        if log != upper_bound_log {
614            writer.write(1, 0)?;
615        }
616        // Write value suffix
617        writer.write(log as usize, ((1u64 << log) - 1) & method)?;
618
619        // Encode alphabet size
620        write_var_len_uint8(writer, (self.alphabet_size() - 3) as u8)?;
621
622        // For a simple implementation, encode each frequency directly
623        // Full implementation would use Huffman for bit-widths + RLE
624        for sym in &self.symbols {
625            // Encode frequency using a simple variable-length code
626            let freq = sym.freq;
627            if freq == 0 {
628                writer.write(1, 0)?;
629            } else {
630                writer.write(1, 1)?;
631                let bits = 16 - freq.leading_zeros();
632                writer.write(4, bits as u64)?;
633                if bits > 0 {
634                    writer.write(bits as usize, freq as u64)?;
635                }
636            }
637        }
638
639        Ok(())
640    }
641}
642
643/// Writes a variable-length uint8.
644fn write_var_len_uint8(writer: &mut BitWriter, n: u8) -> Result<()> {
645    if n == 0 {
646        writer.write(1, 0)?;
647    } else {
648        writer.write(1, 1)?;
649        let nbits = 8 - n.leading_zeros();
650        writer.write(3, (nbits - 1) as u64)?;
651        writer.write((nbits - 1) as usize, (n as u64) - (1u64 << (nbits - 1)))?;
652    }
653    Ok(())
654}
655
/// Floor of log2(n); by convention `floor_log2_ans(0)` is 0.
#[inline]
pub fn floor_log2_ans(n: u32) -> u32 {
    n.checked_ilog2().unwrap_or(0)
}

/// Module-local shorthand for `floor_log2_ans`.
#[inline]
fn floor_log2(n: u32) -> u32 {
    floor_log2_ans(n)
}
666
667/// Precision calculation for frequency encoding.
668///
669/// Determines how many bits of precision to use when encoding a frequency count.
670/// Larger counts can be encoded with less precision.
671///
672/// Matches libjxl's `GetPopulationCountPrecision` from `ans_common.h`.
673pub fn get_population_count_precision(logcount: u32, shift: u32) -> u32 {
674    let logcount_i = logcount as i32;
675    let shift_i = shift as i32;
676    let r = logcount_i.min(shift_i - ((ANS_LOG_TAB_SIZE as i32 - logcount_i) >> 1));
677    r.max(0) as u32
678}
679
/// How many histogram `shift` values to try when normalizing an ANS
/// histogram. More shifts means a slower but more thorough search.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub enum ANSHistogramStrategy {
    /// Only try a few shift values (fastest): shifts 0, 6 and 12 in
    /// `ANSEncodingHistogram::from_histogram_cached`.
    Fast,
    /// Try every other shift value: 0, 2, 4, ..., 12.
    Approximate,
    /// Default. Try every shift from 0 through 11; shift 12 is not in the list
    /// `from_histogram_cached` uses for this strategy.
    #[default]
    Precise,
}
691
/// Normalized ANS histogram for encoding.
///
/// Contains frequency counts normalized to sum to ANS_TAB_SIZE (4096), plus
/// the header method chosen for them and its estimated cost.
#[derive(Clone, Debug)]
pub struct ANSEncodingHistogram {
    /// Normalized frequency counts, indexed by symbol.
    pub counts: Vec<i32>,
    /// Alphabet size as reported by the source histogram's `alphabet_size()`;
    /// set to 1 for an empty histogram.
    pub alphabet_size: usize,
    /// Estimated cost in bits (header + data). `new()` initializes it to
    /// `f32::MAX`.
    pub cost: f32,
    /// Encoding method:
    /// - 0: flat distribution
    /// - 1: small code (1-2 symbols)
    /// - otherwise: shift + 1. `from_histogram_cached` caps the shift at
    ///   `ANS_LOG_TAB_SIZE - 1` here, so shifts 11 and 12 both give 12.
    pub method: u32,
    /// Position of the balancing bin (absorbs rounding error). For the small
    /// code it is the first present symbol.
    pub omit_pos: usize,
    /// Number of symbols with a non-zero count (counted fully, not capped at 2).
    num_symbols: usize,
    /// Indices of the first two symbols with a non-zero count (used by the
    /// small code).
    symbols: [usize; 2],
}
715
716impl ANSEncodingHistogram {
717    /// Create an empty histogram.
718    pub fn new() -> Self {
719        Self {
720            counts: Vec::new(),
721            alphabet_size: 0,
722            cost: f32::MAX,
723            method: 0,
724            omit_pos: 0,
725            num_symbols: 0,
726            symbols: [0, 0],
727        }
728    }
729
730    /// Create from a Histogram with the best normalization.
731    ///
732    /// Tries different shift values and picks the one with lowest cost.
733    /// Use `from_histogram_cached` with a precomputed `AllowedCountsCache`
734    /// when calling this in a loop to avoid repeated table construction.
735    pub fn from_histogram(
736        histo: &super::histogram::Histogram,
737        strategy: ANSHistogramStrategy,
738    ) -> Result<Self> {
739        let cache = AllowedCountsCache::new();
740        Self::from_histogram_cached(histo, strategy, &cache)
741    }
742
743    /// Create from a Histogram using precomputed allowed counts tables.
744    ///
745    /// This is the fast path — call `AllowedCountsCache::new()` once and reuse
746    /// it across all histogram normalization calls to avoid repeated allocation
747    /// and sorting of allowed counts tables.
748    pub fn from_histogram_cached(
749        histo: &super::histogram::Histogram,
750        strategy: ANSHistogramStrategy,
751        cache: &AllowedCountsCache,
752    ) -> Result<Self> {
753        if histo.total_count == 0 {
754            // Empty histogram
755            return Ok(Self {
756                counts: vec![0i32; histo.counts.len().max(1)],
757                alphabet_size: 1,
758                cost: 0.0,
759                method: 0, // Flat
760                omit_pos: 0,
761                num_symbols: 0,
762                symbols: [0, 0],
763            });
764        }
765
766        let alphabet_size = histo.alphabet_size();
767
768        // Count non-zero symbols
769        let mut num_symbols = 0;
770        let mut symbols = [0usize; 2];
771        for (i, &count) in histo.counts.iter().enumerate() {
772            if count > 0 {
773                if num_symbols < 2 {
774                    symbols[num_symbols] = i;
775                }
776                num_symbols += 1;
777            }
778        }
779
780        // Single symbol or two symbols: use small code
781        if num_symbols <= 2 {
782            let mut counts = vec![0i32; alphabet_size];
783            if num_symbols == 1 {
784                counts[symbols[0]] = ANS_TAB_SIZE as i32;
785            } else {
786                // Two symbols: proportional allocation
787                let total = histo.total_count as f64;
788                let count0 = histo.counts[symbols[0]] as f64;
789                let norm0 = ((count0 / total) * ANS_TAB_SIZE as f64).round() as i32;
790                let norm0 = norm0.clamp(1, (ANS_TAB_SIZE - 1) as i32);
791                counts[symbols[0]] = norm0;
792                counts[symbols[1]] = ANS_TAB_SIZE as i32 - norm0;
793            }
794
795            // Cost is just the header
796            let cost = if num_symbols <= 1 { 4.0 } else { 4.0 + 12.0 }; // Approximate
797
798            return Ok(Self {
799                counts,
800                alphabet_size,
801                cost,
802                method: 1, // Small code
803                omit_pos: symbols[0],
804                num_symbols,
805                symbols,
806            });
807        }
808
809        // General case: start with flat distribution as baseline
810        // libjxl always computes flat cost first (enc_ans.cc:97-102) and picks
811        // the cheaper of flat vs shift-based encoding.
812        let flat_data_cost = {
813            let log2_alpha = jxl_simd::fast_log2f(alphabet_size as f32);
814            histo.total_count as f32 * log2_alpha
815        };
816        let flat_header_cost = 2.0 + 8.0; // method=0 marker + alphabet size
817        let mut best = Self {
818            counts: {
819                let alpha = alphabet_size as u32;
820                let per = ANS_TAB_SIZE / alpha;
821                let remainder = (ANS_TAB_SIZE % alpha) as usize;
822                let mut c = vec![per as i32; alphabet_size];
823                // Distribute remainder to first symbols
824                for c in c.iter_mut().take(remainder) {
825                    *c += 1;
826                }
827                c
828            },
829            alphabet_size,
830            cost: flat_header_cost + flat_data_cost,
831            method: 0, // Flat
832            omit_pos: 0,
833            num_symbols,
834            symbols,
835        };
836
837        // Reuse a single candidate buffer across all shift iterations to avoid
838        // allocating a new vec![0i32; alphabet_size] for each shift.
839        let mut candidate_counts = vec![0i32; alphabet_size];
840
841        // Iterate shifts directly without allocating a Vec<u32>.
842        let shift_iter: &[u32] = match strategy {
843            ANSHistogramStrategy::Fast => &[0, 6, 12],
844            ANSHistogramStrategy::Approximate => &[0, 2, 4, 6, 8, 10, 12],
845            ANSHistogramStrategy::Precise => &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
846        };
847
848        for &shift in shift_iter {
849            // Reset candidate counts to zero
850            candidate_counts.fill(0);
851
852            let mut candidate = Self {
853                counts: Vec::new(), // placeholder, swapped in below
854                alphabet_size,
855                cost: f32::MAX,
856                method: shift.min(ANS_LOG_TAB_SIZE - 1) + 1,
857                omit_pos: 0,
858                num_symbols,
859                symbols,
860            };
861
862            // Swap the reusable buffer in for this iteration
863            core::mem::swap(&mut candidate.counts, &mut candidate_counts);
864
865            if candidate.rebalance_histogram_cached(histo, shift, cache.get(shift)) {
866                candidate.cost = candidate.estimate_cost(histo);
867                if candidate.cost < best.cost {
868                    // This candidate wins — take its counts and give it the old best's
869                    // buffer (or a fresh one) for the next iteration
870                    core::mem::swap(&mut candidate_counts, &mut best.counts);
871                    best = candidate;
872                    // best now has the winning counts, candidate_counts has the old flat counts
873                    // (or previous best). Resize if needed.
874                    candidate_counts.resize(alphabet_size, 0);
875                } else {
876                    // Candidate lost — reclaim its buffer
877                    core::mem::swap(&mut candidate.counts, &mut candidate_counts);
878                }
879            } else {
880                // Rebalance failed — reclaim buffer
881                core::mem::swap(&mut candidate.counts, &mut candidate_counts);
882            }
883        }
884
885        if best.cost == f32::MAX {
886            // Debug: dump histogram info
887            eprintln!(
888                "ANS rebalance FAILED: alphabet_size={}, num_symbols={}, total_count={}",
889                alphabet_size, num_symbols, histo.total_count
890            );
891            for (i, &c) in histo.counts.iter().enumerate() {
892                if c > 0 {
893                    eprintln!("  symbol {}: count={}", i, c);
894                }
895            }
896            return Err(Error::InvalidHistogram(
897                "Failed to rebalance histogram".to_string(),
898            ));
899        }
900
901        Ok(best)
902    }
903
    /// Rebalance histogram using precomputed allowed counts table and greedy optimization.
    /// Matches libjxl's `RebalanceHistogram` (enc_ans.cc:416-559).
    ///
    /// Overwrites `self.counts` for the first `alphabet_size` symbols with normalized
    /// counts and sets `self.omit_pos` to the balancing ("remainder") bin. The balancing
    /// bin's count is whatever is left of `ANS_TAB_SIZE` after every other bin is chosen.
    ///
    /// * `histo` - source frequencies; `histo.total_count` is used for scaling.
    /// * `_shift` - unused here; the shift only affects the result through `allowed`.
    /// * `allowed` - table of permitted counts, addressed by index. NOTE(review): the
    ///   code treats a smaller index as a larger count (`allowed[ai - 1]` is the
    ///   "increment" step). So this is presumably sorted descending and holds the
    ///   counts representable at the chosen shift; confirm against the cache builder.
    ///
    /// Returns `true` when the counts sum to `ANS_TAB_SIZE` and `omit_pos` is the first
    /// symbol with the highest logcount, which is the rule the decoder uses to
    /// re-derive it. Returns `false` in these cases:
    /// * the histogram is empty;
    /// * the balancing count is non-positive;
    /// * either final check fails.
    ///
    /// On `false`, `self.counts` may be partially overwritten.
    fn rebalance_histogram_cached(
        &mut self,
        histo: &super::histogram::Histogram,
        _shift: u32,
        allowed: &[i32],
    ) -> bool {
        let total_count = histo.total_count;
        if total_count == 0 {
            return false;
        }

        // Scale factor from raw frequencies to table slots.
        let norm = ANS_TAB_SIZE as f64 / total_count as f64;

        // Find remainder_pos: symbol with highest original frequency (balancing bin).
        // Matches libjxl's remainder_pos selection.
        let mut remainder_pos = 0;
        let mut max_freq = 0i32;

        // Bins eligible for greedy adjustment: (orig_freq, index_in_allowed, symbol_index)
        let mut bins: Vec<(i32, usize, usize)> = Vec::with_capacity(self.alphabet_size);
        // Table slots not yet assigned to any bin.
        let mut rest = ANS_TAB_SIZE as i32;

        for (n, &freq) in histo.counts.iter().enumerate().take(self.alphabet_size) {
            // Strict `>` keeps the FIRST symbol among equal maximum frequencies.
            if freq > max_freq {
                remainder_pos = n;
                max_freq = freq;
            }

            if freq == 0 {
                self.counts[n] = 0;
                continue;
            }

            let target = freq as f64 * norm;
            // Round and clamp to [1, ANS_TAB_SIZE-1], then snap DOWN to allowed
            let rounded = target.round().max(1.0).min((ANS_TAB_SIZE - 1) as f64) as i32;
            let ai = find_allowed_leq(allowed, rounded);
            let count = allowed[ai];

            self.counts[n] = count;
            rest -= count;

            // Only bins with target > 1.0 are adjustable (matches libjxl)
            if target > 1.0 {
                bins.push((freq, ai, n));
            }
        }

        // Remove the balancing bin from the adjustable set
        if let Some(pos) = bins.iter().position(|b| b.2 == remainder_pos) {
            bins.remove(pos);
        }

        // rest now represents what the balancing bin should be.
        // Add back remainder_pos's initial count since it's no longer adjustable.
        rest += self.counts[remainder_pos];

        // Greedy entropy optimization (libjxl enc_ans.cc:495-537).
        // Each iteration: find the best bin to increment or decrement by one
        // allowed-count step, with the balancing bin absorbing the difference.
        if !bins.is_empty() {
            // The balancing bin's cost is weighted by its original frequency.
            let max_freq_f = max_freq as f64;
            // Approximate log2 via `fast_log2f`, widened to f64; non-positive inputs map
            // to 0. libjxl uses a fixed-point lg2 table; this uses floating point instead.
            let lg2 = |v: i32| -> f64 {
                if v <= 0 {
                    0.0
                } else {
                    jxl_simd::fast_log2f(v as f32) as f64
                }
            };

            loop {
                // Find the best increment step (grow a bin, shrink balancing)
                let mut best_inc_net = 0.0f64; // must be > 0 to be taken
                let mut best_inc_bi = None;

                // Find the best decrement step (shrink a bin, grow balancing)
                let mut best_dec_net = 0.0f64; // must be > 0 to be taken
                let mut best_dec_bi = None;

                for (bi, &(freq, ai, _bin)) in bins.iter().enumerate() {
                    let count = allowed[ai];
                    let freq_f = freq as f64;
                    let lg2_count = lg2(count);

                    // Try increment: move to allowed[ai - 1] (higher count)
                    if ai > 0 {
                        let new_count = allowed[ai - 1];
                        let step = new_count - count;
                        let new_rest = rest - step;
                        // Legal if the balancing bin stays positive, or if rest is at or
                        // above ANS_TAB_SIZE and must be pulled down regardless.
                        if new_rest > 0 || rest >= ANS_TAB_SIZE as i32 {
                            // Bits saved by this bin: freq * log2(new_count / count).
                            let gain = freq_f * (lg2(new_count) - lg2_count);
                            // Bits lost by the balancing bin as it shrinks from rest to new_rest.
                            let cost = if rest >= ANS_TAB_SIZE as i32 {
                                0.0 // tractor: pull rest down, no cost
                            } else if rest > 0 && new_rest > 0 {
                                max_freq_f * (lg2(rest) - lg2(new_rest))
                            } else {
                                f64::MAX
                            };
                            let net = gain - cost;
                            // Normalize by step size for fair comparison across step sizes
                            let step_log = floor_log2(step as u32);
                            let norm_net = if step_log > 0 {
                                net / (1u32 << step_log) as f64
                            } else {
                                net
                            };
                            // Strict `>`: the earliest bin wins ties.
                            if norm_net > best_inc_net {
                                best_inc_net = norm_net;
                                best_inc_bi = Some(bi);
                            }
                        }
                    }

                    // Try decrement: move to allowed[ai + 1] (lower count)
                    if ai + 1 < allowed.len() && allowed[ai + 1] > 0 {
                        let new_count = allowed[ai + 1];
                        let step = count - new_count;
                        let new_rest = rest + step;
                        // Legal if the balancing bin stays below ANS_TAB_SIZE, or if rest is
                        // <= 1 and must be pushed up regardless.
                        if new_rest < ANS_TAB_SIZE as i32 || rest <= 1 {
                            // Bits lost by this bin as it shrinks.
                            let loss = freq_f * (lg2_count - lg2(new_count));
                            // Bits saved by the balancing bin as it grows.
                            let gain = if rest <= 1 {
                                f64::MAX // tractor: pull rest up, infinite gain
                            } else if rest > 0 && new_rest < ANS_TAB_SIZE as i32 {
                                max_freq_f * (lg2(new_rest) - lg2(rest))
                            } else {
                                0.0
                            };
                            let net = gain - loss;
                            let step_log = floor_log2(step as u32);
                            let norm_net = if step_log > 0 {
                                net / (1u32 << step_log) as f64
                            } else {
                                net
                            };
                            if norm_net > best_dec_net {
                                best_dec_net = norm_net;
                                best_dec_bi = Some(bi);
                            }
                        }
                    }
                }

                // Prefer increment over decrement (matches libjxl)
                if best_inc_net > 0.0 {
                    if let Some(bi) = best_inc_bi {
                        let step = allowed[bins[bi].1 - 1] - allowed[bins[bi].1];
                        bins[bi].1 -= 1; // move to higher count
                        rest -= step;
                    }
                } else if best_dec_net > 0.0 {
                    if let Some(bi) = best_dec_bi {
                        let step = allowed[bins[bi].1] - allowed[bins[bi].1 + 1];
                        bins[bi].1 += 1; // move to lower count
                        rest += step;
                    }
                } else {
                    break; // No improvement possible
                }
            }

            // Write final counts from allowed table
            for &(_freq, ai, bin) in &bins {
                self.counts[bin] = allowed[ai];
            }

            // Handle omit_pos bit-width constraint (libjxl enc_ans.cc:545-551):
            // If an earlier bin has count >= 2048 (logcount 12), it would tie or beat the
            // balancing bin in the decoder's "first highest logcount" scan. In that case it
            // becomes the balancing bin: its count moves to the old remainder_pos, and it
            // receives `rest` below.
            for n in 0..remainder_pos {
                if self.counts[n] >= 2048 {
                    self.counts[remainder_pos] = self.counts[n];
                    remainder_pos = n;
                    break;
                }
            }
        }

        // Set balancing bin
        self.counts[remainder_pos] = rest;
        self.omit_pos = remainder_pos;

        if rest <= 0 {
            return false;
        }

        // Ensure remainder_pos is the FIRST symbol with the highest logcount.
        // The decoder re-derives omit_pos by scanning symbols in order and picking
        // the first one with the maximum logcount. If another symbol has equal or
        // higher logcount, the decoder picks the wrong one and decoding fails.
        // Offending symbols are lowered and the slack goes to the balancing bin. This
        // can raise omit_logcount, so the scan repeats (at most 10 passes).
        for _ in 0..10 {
            let omit_logcount = floor_log2(self.counts[remainder_pos] as u32) + 1;
            let mut adjusted = false;
            for i in 0..self.alphabet_size {
                if i == remainder_pos || self.counts[i] <= 0 {
                    continue;
                }
                let logcount = floor_log2(self.counts[i] as u32) + 1;
                let needs_fix =
                    logcount > omit_logcount || (logcount == omit_logcount && i < remainder_pos);
                if needs_fix {
                    // Reduce this symbol to a representable value with lower logcount.
                    // Find the highest allowed count with logcount < omit_logcount
                    // (or <= omit_logcount for symbols after remainder_pos).
                    let target_logcount = if i < remainder_pos {
                        omit_logcount.saturating_sub(1)
                    } else {
                        omit_logcount
                    };
                    // Largest count whose logcount is <= target_logcount.
                    let max_value = (1i32 << target_logcount) - 1;
                    let new_ai = find_allowed_leq(allowed, max_value);
                    let new_count = allowed[new_ai].max(1);
                    let reduction = self.counts[i] - new_count;
                    if reduction > 0 {
                        self.counts[i] = new_count;
                        self.counts[remainder_pos] += reduction;
                        adjusted = true;
                    }
                }
            }
            if !adjusted {
                break;
            }
        }

        // Final verification
        let omit_logcount = floor_log2(self.counts[remainder_pos] as u32) + 1;
        for (i, &count) in self.counts.iter().enumerate().take(self.alphabet_size) {
            if i == remainder_pos || count <= 0 {
                continue;
            }
            let logcount = floor_log2(count as u32) + 1;
            if logcount > omit_logcount || (logcount == omit_logcount && i < remainder_pos) {
                return false;
            }
        }

        // Verify sum
        let sum: i32 = self.counts.iter().sum();
        sum == ANS_TAB_SIZE as i32
    }
1148
1149    /// Estimate encoding cost (header + data bits).
1150    /// Uses precise ANS cost model matching libjxl's `Cost()` (enc_ans.cc:376-380).
1151    fn estimate_cost(&self, histo: &super::histogram::Histogram) -> f32 {
1152        let header_cost = self.estimate_header_cost();
1153        let data_cost = estimate_data_bits_normalized(
1154            &histo.counts,
1155            &self.counts,
1156            histo.total_count,
1157            self.alphabet_size,
1158        ) as f32;
1159        header_cost + data_cost
1160    }
1161
1162    /// Estimate header encoding cost.
1163    fn estimate_header_cost(&self) -> f32 {
1164        if self.method == 0 {
1165            // Flat: 2 bits + alphabet size encoding
1166            2.0 + 8.0
1167        } else if self.num_symbols <= 2 {
1168            // Small code
1169            if self.num_symbols <= 1 {
1170                3.0 + 8.0 // nsym=0: marker + symbol
1171            } else {
1172                3.0 + 16.0 + 12.0 // nsym=2: marker + 2 symbols + count
1173            }
1174        } else {
1175            // General code: method encoding + alphabet + frequencies
1176            let method_bits = 4.0; // Unary + suffix for method
1177            let alphabet_bits = 8.0;
1178            let freq_bits = self.alphabet_size as f32 * 5.0; // Rough estimate
1179            method_bits + alphabet_bits + freq_bits
1180        }
1181    }
1182
1183    /// Write this histogram to a BitWriter.
1184    pub fn write(&self, writer: &mut BitWriter) -> Result<()> {
1185        if self.method == 0 {
1186            // Flat distribution
1187            writer.write(1, 0)?; // Non-small
1188            writer.write(1, 1)?; // Flat
1189            write_var_len_uint8(writer, (self.alphabet_size - 1) as u8)?;
1190            return Ok(());
1191        }
1192
1193        if self.num_symbols <= 2 {
1194            // Small code
1195            writer.write(1, 1)?; // Small tree marker
1196            if self.num_symbols == 0 {
1197                writer.write(1, 0)?;
1198                write_var_len_uint8(writer, 0)?;
1199            } else {
1200                writer.write(1, (self.num_symbols - 1) as u64)?;
1201                for i in 0..self.num_symbols {
1202                    write_var_len_uint8(writer, self.symbols[i] as u8)?;
1203                }
1204                if self.num_symbols == 2 {
1205                    writer.write(
1206                        ANS_LOG_TAB_SIZE as usize,
1207                        self.counts[self.symbols[0]] as u64,
1208                    )?;
1209                }
1210            }
1211            return Ok(());
1212        }
1213
1214        // General code
1215        self.write_general(writer)
1216    }
1217
1218    /// Write general (non-flat, non-small) histogram.
1219    ///
1220    /// Matches the format expected by jxl-rs `decode_dist_complex()`.
1221    fn write_general(&self, writer: &mut BitWriter) -> Result<()> {
1222        writer.write(1, 0)?; // Non-small
1223        writer.write(1, 0)?; // Non-flat
1224
1225        // Encode shift using unary + suffix (method = shift + 1)
1226        // Format: len ones, then 0 (unless at max), then len suffix bits
1227        let shift = (self.method - 1) as i32;
1228        let shift_val = (shift + 1) as u32; // shift+1 is stored, range 1-13
1229
1230        // Determine unary length
1231        let mut len = 0u32;
1232        while len < 3 && shift_val >= (1u32 << (len + 1)) {
1233            len += 1;
1234        }
1235
1236        // Write unary prefix (len ones)
1237        for _ in 0..len {
1238            writer.write(1, 1)?;
1239        }
1240        // Write terminating 0 if len < 3
1241        if len < 3 {
1242            writer.write(1, 0)?;
1243        }
1244        // Write suffix bits
1245        if len > 0 {
1246            let suffix = shift_val - (1u32 << len);
1247            writer.write(len as usize, suffix as u64)?;
1248        }
1249
1250        // Encode alphabet size - 3
1251        if self.alphabet_size < 3 {
1252            return Err(Error::InvalidHistogram(
1253                "General histogram needs at least 3 symbols".to_string(),
1254            ));
1255        }
1256        write_var_len_uint8(writer, (self.alphabet_size - 3) as u8)?;
1257
1258        // Pre-compute logcounts for all symbols
1259        let logcounts: Vec<u32> = (0..self.alphabet_size)
1260            .map(|i| {
1261                let count = self.counts[i];
1262                if count <= 0 {
1263                    0
1264                } else {
1265                    floor_log2(count as u32) + 1
1266                }
1267            })
1268            .collect();
1269
1270        // Pre-compute RLE: for each position i, same[i] = number of consecutive
1271        // symbols starting at i+1 that have the same actual count as i.
1272        // The decoder fills RLE positions with prev_dist (the actual count), so
1273        // all symbols in a run must have identical normalized counts (not just logcounts).
1274        // Constraints (libjxl enc_ans.cc:257-273):
1275        // - RLE range must not include omit_pos
1276        // - RLE marker must not appear at omit_pos+1
1277        let mut same = vec![0usize; self.alphabet_size];
1278        #[allow(clippy::needless_range_loop)]
1279        for i in 0..self.alphabet_size {
1280            if i == self.omit_pos {
1281                continue;
1282            }
1283            let mut run = 0;
1284            let mut j = i + 1;
1285            while j < self.alphabet_size && self.counts[j] == self.counts[i] {
1286                if j == self.omit_pos {
1287                    break; // Can't include omit_pos in RLE range
1288                }
1289                run += 1;
1290                j += 1;
1291            }
1292            same[i] = run;
1293        }
1294
1295        // Encode log-frequency values with RLE (libjxl enc_ans.cc:300-309).
1296        // The decoder determines omit_pos as the first symbol with highest logcount.
1297        const MIN_REPS: usize = 4; // Minimum repeat count (decoder reads value+4)
1298        let mut i = 0;
1299        while i < self.alphabet_size {
1300            // Write the logcount using fixed prefix code
1301            let (nbits, code) = LOGCOUNT_PREFIX_CODE[logcounts[i] as usize];
1302            writer.write(nbits as usize, code as u64)?;
1303
1304            // If 4+ following symbols have the same logcount, use RLE.
1305            // But don't place RLE marker at omit_pos+1 (decoder rejects this).
1306            if same[i] >= MIN_REPS && i + 1 != self.omit_pos + 1 {
1307                let (rle_nbits, rle_code) = LOGCOUNT_PREFIX_CODE[RLE_MARKER_SYM as usize];
1308                writer.write(rle_nbits as usize, rle_code as u64)?;
1309                write_var_len_uint8(writer, (same[i] - MIN_REPS) as u8)?;
1310                i += same[i]; // Skip the repeated symbols
1311            }
1312            i += 1;
1313        }
1314
1315        // Build set of RLE-covered positions. The decoder skips precision bits
1316        // for these symbols (the `continue` in the RLE range handler).
1317        let mut rle_covered = vec![false; self.alphabet_size];
1318        {
1319            let mut i = 0;
1320            while i < self.alphabet_size {
1321                if same[i] >= MIN_REPS && i + 1 != self.omit_pos + 1 {
1322                    // Positions i+1 through i+same[i] are RLE-covered
1323                    for item in rle_covered.iter_mut().take(i + same[i] + 1).skip(i + 1) {
1324                        *item = true;
1325                    }
1326                    i += same[i];
1327                }
1328                i += 1;
1329            }
1330        }
1331
1332        // Now write precision bits for each non-zero, non-omit symbol with logcount > 1.
1333        // Skip RLE-covered positions (decoder skips precision bits for those).
1334        for i in 0..self.alphabet_size {
1335            if i == self.omit_pos || rle_covered[i] {
1336                continue;
1337            }
1338
1339            let count = self.counts[i];
1340            if count <= 0 {
1341                continue;
1342            }
1343
1344            let logcount = logcounts[i];
1345            if logcount <= 1 {
1346                // logcount=1 means freq=1, no precision bits needed
1347                continue;
1348            }
1349
1350            // zeros = logcount - 1 (the log2 of the frequency)
1351            let zeros = (logcount - 1) as i32;
1352            // bitcount = shift - (12 - zeros) / 2, clamped to [0, zeros]
1353            let bitcount = (shift - ((ANS_LOG_TAB_SIZE as i32 - zeros) >> 1)).clamp(0, zeros);
1354
1355            if bitcount > 0 {
1356                // The value stored is: freq = (1 << zeros) + (extra << (zeros - bitcount))
1357                // So: extra = (freq - (1 << zeros)) >> (zeros - bitcount)
1358                let base = 1i32 << zeros;
1359                let extra = ((count - base) >> (zeros - bitcount)) as u32;
1360                writer.write(bitcount as usize, extra as u64)?;
1361            }
1362        }
1363
1364        Ok(())
1365    }
1366}
1367
1368impl Default for ANSEncodingHistogram {
1369    fn default() -> Self {
1370        Self::new()
1371    }
1372}
1373
1374/// Encodes tokens using ANS.
1375pub fn encode_tokens_ans(
1376    tokens: &[(u32, u32)], // (context, value) pairs
1377    distributions: &[AnsDistribution],
1378    context_map: &[usize],
1379    writer: &mut BitWriter,
1380) -> Result<()> {
1381    let mut encoder = AnsEncoder::new();
1382
1383    // Process tokens in reverse order (ANS requirement)
1384    for &(context, value) in tokens.iter().rev() {
1385        let dist_idx = context_map.get(context as usize).copied().unwrap_or(0);
1386        if let Some(dist) = distributions.get(dist_idx)
1387            && let Some(info) = dist.get(value as usize)
1388        {
1389            encoder.put_symbol(info);
1390        }
1391    }
1392
1393    encoder.finalize(writer)
1394}
1395
1396#[cfg(test)]
1397mod tests {
1398    use super::*;
1399    use crate::entropy_coding::histogram::Histogram;
1400
1401    #[test]
1402    fn test_ans_encoding_histogram_single_symbol() {
1403        let h = Histogram::from_counts(&[100, 0, 0, 0]);
1404        let encoded = ANSEncodingHistogram::from_histogram(&h, ANSHistogramStrategy::Fast).unwrap();
1405
1406        assert_eq!(encoded.num_symbols, 1);
1407        assert_eq!(encoded.method, 1); // Small code
1408        assert_eq!(encoded.counts[0], ANS_TAB_SIZE as i32);
1409        assert!(encoded.cost < 100.0); // Header only
1410    }
1411
1412    #[test]
1413    fn test_ans_encoding_histogram_two_symbols() {
1414        let h = Histogram::from_counts(&[100, 100, 0, 0]);
1415        let encoded = ANSEncodingHistogram::from_histogram(&h, ANSHistogramStrategy::Fast).unwrap();
1416
1417        assert_eq!(encoded.num_symbols, 2);
1418        assert_eq!(encoded.method, 1); // Small code
1419        // Should split roughly 50/50
1420        let sum: i32 = encoded.counts.iter().sum();
1421        assert_eq!(sum, ANS_TAB_SIZE as i32);
1422        assert!(encoded.counts[0] > 0);
1423        assert!(encoded.counts[1] > 0);
1424    }
1425
1426    #[test]
1427    fn test_ans_encoding_histogram_general() {
1428        let h = Histogram::from_counts(&[100, 50, 25, 10, 5, 3, 2, 1]);
1429        let encoded = ANSEncodingHistogram::from_histogram(&h, ANSHistogramStrategy::Fast).unwrap();
1430
1431        // Should use general code (more than 2 symbols)
1432        assert!(encoded.method >= 2 || encoded.method == 0);
1433
1434        // Sum should be exactly ANS_TAB_SIZE
1435        let sum: i32 = encoded.counts.iter().sum();
1436        assert_eq!(sum, ANS_TAB_SIZE as i32);
1437
1438        // All non-zero original counts should have non-zero normalized counts
1439        for (i, &orig) in h.counts.iter().enumerate() {
1440            if orig > 0 {
1441                assert!(
1442                    encoded.counts.get(i).copied().unwrap_or(0) > 0,
1443                    "Symbol {} had count {} but normalized to 0",
1444                    i,
1445                    orig
1446                );
1447            }
1448        }
1449    }
1450
1451    #[test]
1452    fn test_ans_encoding_histogram_empty() {
1453        let h = Histogram::new();
1454        let encoded = ANSEncodingHistogram::from_histogram(&h, ANSHistogramStrategy::Fast).unwrap();
1455
1456        assert_eq!(encoded.cost, 0.0);
1457        assert_eq!(encoded.method, 0); // Flat
1458    }
1459
1460    #[test]
1461    fn test_get_population_count_precision() {
1462        // logcount=0, any shift: precision=0
1463        assert_eq!(get_population_count_precision(0, 12), 0);
1464
1465        // logcount=12, shift=12: min(12, 12-(0/2)) = 12
1466        assert_eq!(get_population_count_precision(12, 12), 12);
1467
1468        // logcount=6, shift=6: min(6, 6-3) = 3
1469        assert_eq!(get_population_count_precision(6, 6), 3);
1470
1471        // logcount=1, shift=0: min(1, 0-(11/2)) = min(1, -5) = 0 (clamped)
1472        assert_eq!(get_population_count_precision(1, 0), 0);
1473    }
1474
1475    #[test]
1476    fn test_ans_encoding_histogram_write() {
1477        let h = Histogram::from_counts(&[100, 0, 0, 0]);
1478        let encoded = ANSEncodingHistogram::from_histogram(&h, ANSHistogramStrategy::Fast).unwrap();
1479
1480        let mut writer = BitWriter::new();
1481        encoded.write(&mut writer).unwrap();
1482
1483        let bytes = writer.finish_with_padding();
1484        assert!(!bytes.is_empty());
1485    }
1486
1487    #[test]
1488    fn test_flat_distribution() {
1489        let dist = AnsDistribution::flat(16).unwrap();
1490        assert_eq!(dist.alphabet_size(), 16);
1491
1492        // All frequencies should be 256 (4096 / 16)
1493        for sym in &dist.symbols {
1494            assert_eq!(sym.freq, 256);
1495        }
1496    }
1497
1498    #[test]
1499    fn test_from_frequencies() {
1500        let freqs = vec![100, 200, 300, 400];
1501        let dist = AnsDistribution::from_frequencies(&freqs).unwrap();
1502        assert_eq!(dist.alphabet_size(), 4);
1503
1504        // Total should be ANS_TAB_SIZE
1505        let total: u32 = dist.symbols.iter().map(|s| s.freq as u32).sum();
1506        assert_eq!(total, ANS_TAB_SIZE);
1507    }
1508
1509    #[test]
1510    fn test_ans_encoder_basic() {
1511        let dist = AnsDistribution::flat(4).unwrap();
1512        let mut encoder = AnsEncoder::new();
1513
1514        // Encode a few symbols
1515        encoder.put_symbol(&dist.symbols[0]);
1516        encoder.put_symbol(&dist.symbols[1]);
1517        encoder.put_symbol(&dist.symbols[2]);
1518
1519        // State should have changed from initial
1520        assert_ne!(encoder.state(), ANS_SIGNATURE << 16);
1521    }
1522
1523    #[test]
1524    fn test_reverse_map() {
1525        let dist = AnsDistribution::flat(4).unwrap();
1526
1527        // Each symbol should have freq entries in reverse_map
1528        for sym in &dist.symbols {
1529            assert_eq!(sym.reverse_map.len(), sym.freq as usize);
1530        }
1531
1532        // All positions 0..4096 should be covered exactly once
1533        let mut covered = vec![false; ANS_TAB_SIZE as usize];
1534        for sym in &dist.symbols {
1535            for &pos in &sym.reverse_map {
1536                assert!(!covered[pos as usize], "position {} covered twice", pos);
1537                covered[pos as usize] = true;
1538            }
1539        }
1540        assert!(covered.iter().all(|&c| c), "not all positions covered");
1541    }
1542
1543    #[test]
1544    fn test_write_distribution() {
1545        let dist = AnsDistribution::flat(16).unwrap();
1546        let mut writer = BitWriter::new();
1547        dist.write(&mut writer).unwrap();
1548
1549        let bytes = writer.finish_with_padding();
1550        // Should produce some output
1551        assert!(!bytes.is_empty());
1552    }
1553
1554    #[test]
1555    fn test_ans_roundtrip_manual() {
1556        // Create a simple flat distribution
1557        let dist = AnsDistribution::flat(2).unwrap();
1558
1559        println!("Distribution: {} symbols", dist.alphabet_size());
1560        for (i, sym) in dist.symbols.iter().enumerate() {
1561            println!("  Symbol {}: freq={}", i, sym.freq);
1562        }
1563
1564        // Encode symbol 0
1565        let mut encoder = AnsEncoder::new();
1566        let initial_state = encoder.state();
1567        println!("\nInitial state: 0x{:08x}", initial_state);
1568        assert_eq!(initial_state, 0x130000, "Initial state should be 0x130000");
1569
1570        let info = &dist.symbols[0];
1571        encoder.put_symbol(info);
1572        let encoded_state = encoder.state();
1573        println!("After encoding symbol 0: state=0x{:08x}", encoded_state);
1574
1575        // Now manually decode to verify
1576        let idx = encoded_state & 0xFFF;
1577        println!("Decode: idx = {}", idx);
1578
1579        // For flat distribution of 2, each has freq 2048
1580        // Symbol 0: cumul=0, freq=2048 -> positions [0, 2048)
1581        // Symbol 1: cumul=2048, freq=2048 -> positions [2048, 4096)
1582        let decoded_symbol = if idx < 2048 { 0 } else { 1 };
1583        let offset_in_symbol = if idx < 2048 { idx } else { idx - 2048 };
1584        let freq = 2048u32;
1585
1586        println!("Decoded symbol: {}", decoded_symbol);
1587        println!("Offset in symbol: {}", offset_in_symbol);
1588
1589        // The decoder does: next_state = (state >> 12) * freq + offset
1590        let quotient = encoded_state >> 12;
1591        let next_state = quotient * freq + offset_in_symbol;
1592        println!(
1593            "next_state = {} * {} + {} = 0x{:08x}",
1594            quotient, freq, offset_in_symbol, next_state
1595        );
1596
1597        // The next_state should be the initial state (0x130000)
1598        assert_eq!(next_state, 0x130000, "Decoded state should be 0x130000");
1599        assert_eq!(decoded_symbol, 0, "Decoded symbol should be 0");
1600    }
1601
1602    #[test]
1603    fn test_ans_roundtrip_multiple_symbols() {
1604        use crate::bit_writer::BitWriter;
1605        use crate::entropy_coding::ans_decode::{AnsHistogram, AnsReader, BitReader};
1606
1607        // Test encoding multiple symbols and verify they can be decoded
1608        // using the jxl-rs compatible decoder (alias table method)
1609
1610        // Create a flat distribution with 4 symbols (each freq = 1024)
1611        let counts = [1024i32, 1024, 1024, 1024];
1612        let dist = AnsDistribution::from_normalized_counts(&counts).unwrap();
1613
1614        let symbols_to_encode: Vec<usize> = vec![0, 1, 2, 3, 0, 1];
1615        println!(
1616            "Encoding {} symbols: {:?}",
1617            symbols_to_encode.len(),
1618            symbols_to_encode
1619        );
1620
1621        // Encode in reverse order (as ANS requires)
1622        let mut encoder = AnsEncoder::new();
1623        for &sym in symbols_to_encode.iter().rev() {
1624            encoder.put_symbol(&dist.symbols[sym]);
1625        }
1626
1627        println!("Final state after encoding: 0x{:08x}", encoder.state());
1628
1629        // Finalize encoder to bitstream
1630        let mut writer = BitWriter::new();
1631        encoder.finalize(&mut writer).unwrap();
1632        let encoded_bytes = writer.finish_with_padding();
1633        println!("Encoded bytes: {:02x?}", encoded_bytes);
1634
1635        // Build decoder histogram by writing and reading back
1636        let ans_histo = ANSEncodingHistogram::from_histogram(
1637            &Histogram::from_counts(&counts),
1638            ANSHistogramStrategy::Precise,
1639        )
1640        .unwrap();
1641        let mut hist_writer = BitWriter::new();
1642        ans_histo.write(&mut hist_writer).unwrap();
1643        let hist_bytes = hist_writer.finish_with_padding();
1644
1645        let mut hist_br = BitReader::new(&hist_bytes);
1646        let decoded_hist = AnsHistogram::decode(&mut hist_br, 6).unwrap();
1647
1648        println!(
1649            "Decoded histogram frequencies: {:?}",
1650            &decoded_hist.frequencies[..4]
1651        );
1652
1653        // Decode using jxl-rs compatible decoder
1654        let mut br = BitReader::new(&encoded_bytes);
1655        let mut ans_reader = AnsReader::init(&mut br).unwrap();
1656
1657        println!("Decoding:");
1658        let mut decoded = Vec::new();
1659        for i in 0..symbols_to_encode.len() {
1660            let symbol = decoded_hist.read(&mut br, &mut ans_reader.0) as usize;
1661            println!(
1662                "  step {}: symbol={}, state=0x{:08x}",
1663                i, symbol, ans_reader.0
1664            );
1665            decoded.push(symbol);
1666        }
1667
1668        println!("Original: {:?}", symbols_to_encode);
1669        println!("Decoded:  {:?}", decoded);
1670        println!("Final state: 0x{:08x}", ans_reader.0);
1671
1672        assert_eq!(
1673            decoded, symbols_to_encode,
1674            "Decoded symbols should match original"
1675        );
1676        assert!(
1677            ans_reader.check_final_state().is_ok(),
1678            "Final state should be 0x130000, got 0x{:08x}",
1679            ans_reader.0
1680        );
1681    }
1682
1683    #[test]
1684    fn test_ans_histogram_write_decode_roundtrip() {
1685        use crate::bit_writer::BitWriter;
1686        use crate::entropy_coding::histogram::Histogram;
1687
1688        // Create a histogram with several symbols
1689        let histo = Histogram::from_counts(&[100, 50, 25, 10]);
1690
1691        let encoded =
1692            ANSEncodingHistogram::from_histogram(&histo, ANSHistogramStrategy::Precise).unwrap();
1693
1694        println!("Histogram: {:?}", histo.counts);
1695        println!("Encoded counts: {:?}", encoded.counts);
1696        println!(
1697            "Method: {}, alphabet_size: {}, omit_pos: {}",
1698            encoded.method, encoded.alphabet_size, encoded.omit_pos
1699        );
1700
1701        // Verify sum is 4096
1702        let sum: i32 = encoded.counts.iter().sum();
1703        assert_eq!(sum, ANS_TAB_SIZE as i32, "Sum should be 4096");
1704
1705        // Write to bitstream
1706        let mut writer = BitWriter::new();
1707        encoded.write(&mut writer).unwrap();
1708        let bytes = writer.finish_with_padding();
1709
1710        println!("Encoded histogram: {} bytes", bytes.len());
1711        println!("Bytes: {:02x?}", bytes);
1712    }
1713}