Skip to main content

oximedia_codec/jpegxl/
entropy.rs

1//! ANS (Asymmetric Numeral Systems) entropy coding for JPEG-XL.
2//!
3//! JPEG-XL uses rANS (range ANS) for entropy coding of transform coefficients
4//! and modular residuals. This module implements both encoding and decoding
5//! with distribution tables.
6//!
7//! ANS provides near-optimal compression (close to entropy limit) with
8//! LIFO (stack) semantics: symbols must be encoded in reverse order and
9//! decoded in forward order.
10//!
11//! This implementation uses 16-bit word-based renormalization for simplicity
12//! and correctness.
13
14use crate::error::{CodecError, CodecResult};
15
/// Base-2 logarithm of the default ANS table size (`1 << 10` = 1024 slots).
const DEFAULT_LOG_TABLE_SIZE: u8 = 10;

/// Width, in bits, of each word moved in or out of the coder state during
/// renormalization.
const RENORM_WORD_BITS: u32 = 16;
21
/// ANS probability distribution table.
///
/// Associates each symbol with a frequency count. The frequencies are
/// expected to add up to `1 << log_table_size` (the table size), and
/// `cumulative` holds their prefix sums with a leading zero, so symbol `i`
/// owns the slot range `cumulative[i]..cumulative[i + 1]`.
///
/// The three vectors are index-aligned: entry `i` of `frequencies` belongs to
/// entry `i` of `symbols`, and `cumulative` has one more entry than the others.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct AnsDistribution {
    /// Symbol values in the distribution.
    pub symbols: Vec<u16>,
    /// Frequency (probability count) for each symbol.
    pub frequencies: Vec<u32>,
    /// Cumulative frequency table (prefix sums of frequencies).
    pub cumulative: Vec<u32>,
    /// Log2 of the total frequency table size.
    pub log_table_size: u8,
}
37
38impl AnsDistribution {
39    /// Create a new distribution from symbol frequencies.
40    ///
41    /// Frequencies are normalized so they sum to `1 << log_table_size`.
42    pub fn new(symbols: Vec<u16>, frequencies: Vec<u32>, log_table_size: u8) -> CodecResult<Self> {
43        if symbols.len() != frequencies.len() {
44            return Err(CodecError::InvalidParameter(
45                "Symbol and frequency vectors must have the same length".into(),
46            ));
47        }
48        if symbols.is_empty() {
49            return Err(CodecError::InvalidParameter(
50                "Distribution must have at least one symbol".into(),
51            ));
52        }
53
54        let total: u32 = frequencies.iter().sum();
55        if total == 0 {
56            return Err(CodecError::InvalidParameter(
57                "Total frequency must be non-zero".into(),
58            ));
59        }
60
61        let table_size = 1u32 << log_table_size;
62
63        // Normalize frequencies to sum to table_size
64        let mut normalized: Vec<u32> = frequencies
65            .iter()
66            .map(|&f| {
67                if f == 0 {
68                    0
69                } else {
70                    let n = (f as u64 * table_size as u64 / total as u64) as u32;
71                    if n == 0 {
72                        1
73                    } else {
74                        n
75                    }
76                }
77            })
78            .collect();
79
80        // Adjust to ensure sum equals table_size exactly
81        let current_sum: u32 = normalized.iter().sum();
82        if current_sum != table_size {
83            let diff = table_size as i64 - current_sum as i64;
84            if let Some(max_idx) = normalized
85                .iter()
86                .enumerate()
87                .filter(|(_, &f)| f > 0)
88                .max_by_key(|(_, &f)| f)
89                .map(|(i, _)| i)
90            {
91                let adjusted = normalized[max_idx] as i64 + diff;
92                if adjusted > 0 {
93                    normalized[max_idx] = adjusted as u32;
94                }
95            }
96        }
97
98        // Build cumulative table
99        let mut cumulative = Vec::with_capacity(normalized.len() + 1);
100        cumulative.push(0);
101        let mut sum = 0u32;
102        for &f in &normalized {
103            sum += f;
104            cumulative.push(sum);
105        }
106
107        Ok(Self {
108            symbols,
109            frequencies: normalized,
110            cumulative,
111            log_table_size,
112        })
113    }
114
115    /// Total table size (sum of all frequencies).
116    pub fn table_size(&self) -> u32 {
117        1u32 << self.log_table_size
118    }
119
120    /// Number of symbols in the distribution.
121    pub fn num_symbols(&self) -> usize {
122        self.symbols.len()
123    }
124
125    /// Look up a symbol by its cumulative frequency position.
126    pub fn lookup(&self, value: u32) -> CodecResult<(usize, u32, u32)> {
127        let mut lo = 0usize;
128        let mut hi = self.symbols.len();
129        while lo < hi {
130            let mid = lo + (hi - lo) / 2;
131            if self.cumulative[mid + 1] <= value {
132                lo = mid + 1;
133            } else {
134                hi = mid;
135            }
136        }
137        if lo >= self.symbols.len() {
138            return Err(CodecError::InvalidBitstream(format!(
139                "ANS lookup failed: value {value} out of range"
140            )));
141        }
142        Ok((lo, self.cumulative[lo], self.frequencies[lo]))
143    }
144
145    /// Find the index of a symbol in the distribution.
146    fn find_symbol(&self, symbol: u16) -> CodecResult<usize> {
147        self.symbols
148            .iter()
149            .position(|&s| s == symbol)
150            .ok_or_else(|| {
151                CodecError::InvalidParameter(format!("Symbol {symbol} not found in distribution"))
152            })
153    }
154}
155
156/// Build a uniform distribution for N symbols.
157pub fn uniform_distribution(n: u16) -> CodecResult<AnsDistribution> {
158    if n == 0 {
159        return Err(CodecError::InvalidParameter(
160            "Cannot create uniform distribution with 0 symbols".into(),
161        ));
162    }
163    let symbols: Vec<u16> = (0..n).collect();
164    let freq = vec![1u32; n as usize];
165    AnsDistribution::new(symbols, freq, DEFAULT_LOG_TABLE_SIZE)
166}
167
168/// Build a distribution from observed frequency counts.
169pub fn distribution_from_counts(
170    counts: &[u32],
171    log_table_size: u8,
172) -> CodecResult<AnsDistribution> {
173    let mut symbols = Vec::new();
174    let mut frequencies = Vec::new();
175
176    for (i, &count) in counts.iter().enumerate() {
177        if count > 0 {
178            symbols.push(i as u16);
179            frequencies.push(count);
180        }
181    }
182
183    if symbols.is_empty() {
184        symbols.push(0);
185        frequencies.push(1);
186    }
187
188    AnsDistribution::new(symbols, frequencies, log_table_size)
189}
190
/// rANS decoder.
///
/// Decodes symbols from a stream of 16-bit words.
/// Stream format: [state: u32 LE] [word_count: u32 LE] [words: u16 LE...]
///
/// Words are read in FIFO order (first word in stream is first word consumed).
#[derive(Clone, Debug)]
pub struct AnsDecoder<'a> {
    /// Current coder state.
    state: u32,
    /// The complete encoded buffer, header included.
    data: &'a [u8],
    /// Current read position in data (after the 8-byte header).
    word_pos: usize,
}
203
204impl<'a> AnsDecoder<'a> {
205    /// Create a new ANS decoder from encoded data.
206    pub fn new(data: &'a [u8]) -> CodecResult<Self> {
207        if data.len() < 8 {
208            return Err(CodecError::InvalidBitstream("ANS data too short".into()));
209        }
210        let state = u32::from_le_bytes([data[0], data[1], data[2], data[3]]);
211        // word_count at bytes 4..8 (we don't strictly need it, just read sequentially)
212        Ok(Self {
213            state,
214            data,
215            word_pos: 8,
216        })
217    }
218
219    /// Read one 16-bit renormalization word.
220    fn read_word(&mut self) -> u16 {
221        if self.word_pos + 1 < self.data.len() {
222            let w = u16::from_le_bytes([self.data[self.word_pos], self.data[self.word_pos + 1]]);
223            self.word_pos += 2;
224            w
225        } else {
226            0
227        }
228    }
229
230    /// Decode a single symbol using the given distribution.
231    pub fn decode_symbol(&mut self, dist: &AnsDistribution) -> CodecResult<u16> {
232        let table_size = dist.table_size();
233        let mask = table_size - 1;
234
235        let slot = self.state & mask;
236        let (idx, start, freq) = dist.lookup(slot)?;
237        let symbol = dist.symbols[idx];
238
239        // Update state: state = freq * (state >> log_table_size) + slot - start
240        self.state = freq * (self.state >> dist.log_table_size) + slot - start;
241
242        // Renormalize: if state dropped below table_size, read a 16-bit word
243        if self.state < table_size {
244            let word = self.read_word() as u32;
245            self.state = (self.state << RENORM_WORD_BITS) | word;
246        }
247
248        Ok(symbol)
249    }
250}
251
/// rANS encoder.
///
/// Encodes symbols using probability distributions. Due to LIFO semantics,
/// symbols must be encoded in reverse of the desired decode order.
///
/// Uses 16-bit word-based renormalization. The encoder accumulates words
/// into a buffer. On finish, it outputs:
/// [state: u32 LE] [word_count: u32 LE] [words in reverse order: u16 LE...]
#[derive(Clone, Debug)]
pub struct AnsEncoder {
    /// Current coder state.
    state: u32,
    /// Renormalization words accumulated during encoding.
    words: Vec<u16>,
    /// Log2 of the table size recorded when the encoder was created.
    log_table_size: u8,
}
266
267impl AnsEncoder {
268    /// Create a new ANS encoder.
269    pub fn new() -> Self {
270        let log_table_size = DEFAULT_LOG_TABLE_SIZE;
271        let table_size = 1u32 << log_table_size;
272        Self {
273            state: table_size, // initial state = table_size (lower bound of valid range)
274            words: Vec::new(),
275            log_table_size,
276        }
277    }
278
279    /// Encode a single symbol.
280    pub fn encode_symbol(&mut self, symbol: u16, dist: &AnsDistribution) -> CodecResult<()> {
281        let idx = dist.find_symbol(symbol)?;
282        let start = dist.cumulative[idx];
283        let freq = dist.frequencies[idx];
284
285        if freq == 0 {
286            return Err(CodecError::InvalidParameter(format!(
287                "Symbol {symbol} has zero frequency"
288            )));
289        }
290
291        let table_size = dist.table_size();
292
293        // Renormalize: output 16-bit words while state is too large.
294        // After renorm, state must be in [freq, freq * (1 << RENORM_WORD_BITS))
295        // so that the encoding step produces a state in [table_size, table_size * (1 << RENORM_WORD_BITS))
296        let upper_bound = freq << RENORM_WORD_BITS;
297        while self.state >= upper_bound {
298            self.words.push(self.state as u16);
299            self.state >>= RENORM_WORD_BITS;
300        }
301
302        // Encode: state = table_size * (state / freq) + (state % freq) + start
303        self.state = table_size * (self.state / freq) + (self.state % freq) + start;
304
305        Ok(())
306    }
307
308    /// Finish encoding and return the encoded byte buffer.
309    pub fn finish(self) -> Vec<u8> {
310        let word_count = self.words.len() as u32;
311        let mut output = Vec::with_capacity(8 + self.words.len() * 2);
312
313        // Write final state
314        output.extend_from_slice(&self.state.to_le_bytes());
315        // Write word count
316        output.extend_from_slice(&word_count.to_le_bytes());
317        // Write words in reverse order (LIFO -> FIFO for decoder)
318        for &word in self.words.iter().rev() {
319            output.extend_from_slice(&word.to_le_bytes());
320        }
321
322        output
323    }
324}
325
326impl Default for AnsEncoder {
327    fn default() -> Self {
328        Self::new()
329    }
330}
331
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE(review): every test in this module is marked `#[ignore]`; the
    // reason is not visible in this file, so the attribute is kept as is.

    /// Feeds `sequence` to a fresh encoder back-to-front (ANS is LIFO) and
    /// returns the finished byte stream.
    fn encode_reversed(sequence: &[u16], dist: &AnsDistribution) -> Vec<u8> {
        let mut encoder = AnsEncoder::new();
        for &symbol in sequence.iter().rev() {
            encoder.encode_symbol(symbol, dist).expect("ok");
        }
        encoder.finish()
    }

    /// Four equally likely symbols each get about a quarter of the table.
    #[test]
    #[ignore]
    fn test_uniform_distribution() {
        let dist = uniform_distribution(4).expect("ok");
        assert_eq!(dist.num_symbols(), 4);
        assert_eq!(dist.table_size(), 1 << DEFAULT_LOG_TABLE_SIZE);

        let share = i64::from(dist.table_size() / 4);
        assert!(dist
            .frequencies
            .iter()
            .all(|&f| (i64::from(f) - share).unsigned_abs() <= 1));
    }

    /// Symbols with a zero count are left out of the distribution.
    #[test]
    #[ignore]
    fn test_distribution_from_counts() {
        let dist = distribution_from_counts(&[10u32, 20, 30, 0, 40], 10).expect("ok");
        assert_eq!(dist.num_symbols(), 4);
        assert_eq!(dist.symbols, vec![0, 1, 2, 4]);
    }

    /// The cumulative table starts at zero and ends at the table size.
    #[test]
    #[ignore]
    fn test_distribution_cumulative() {
        let dist = AnsDistribution::new(vec![0, 1, 2], vec![256, 512, 256], 10).expect("ok");
        assert_eq!(dist.cumulative[0], 0);

        let last = *dist.cumulative.last().expect("has last");
        assert_eq!(last, dist.table_size());
    }

    /// The first and last slots map to the first and last symbol.
    #[test]
    #[ignore]
    fn test_distribution_lookup() {
        let dist = AnsDistribution::new(vec![0, 1], vec![512, 512], 10).expect("ok");

        let (first_idx, first_start, first_freq) = dist.lookup(0).expect("ok");
        assert_eq!(first_idx, 0);
        assert_eq!(first_start, 0);
        assert!(first_freq > 0);

        let top_slot = dist.table_size() - 1;
        let (last_idx, _, _) = dist.lookup(top_slot).expect("ok");
        assert_eq!(last_idx, 1);
    }

    /// A single encoded symbol decodes to itself.
    #[test]
    #[ignore]
    fn test_ans_roundtrip_single_symbol() {
        let dist = uniform_distribution(4).expect("ok");
        let encoded = encode_reversed(&[2], &dist);

        let mut decoder = AnsDecoder::new(&encoded).expect("ok");
        assert_eq!(decoder.decode_symbol(&dist).expect("ok"), 2);
    }

    /// A permutation of eight symbols survives the round trip in order.
    #[test]
    #[ignore]
    fn test_ans_roundtrip_sequence() {
        let dist = uniform_distribution(8).expect("ok");
        let symbols_to_encode: Vec<u16> = vec![0, 3, 7, 1, 5, 2, 6, 4];

        // Encoded in reverse order, decoded in forward order.
        let encoded = encode_reversed(&symbols_to_encode, &dist);
        let mut decoder = AnsDecoder::new(&encoded).expect("ok");
        for expected in symbols_to_encode.iter().copied() {
            let decoded = decoder.decode_symbol(&dist).expect("ok");
            assert_eq!(decoded, expected, "ANS roundtrip mismatch");
        }
    }

    /// Round trip with strongly uneven symbol probabilities.
    #[test]
    #[ignore]
    fn test_ans_roundtrip_skewed_distribution() {
        let dist =
            AnsDistribution::new(vec![0, 1, 2, 3], vec![700, 200, 80, 20], 10).expect("ok");
        let test_seq: Vec<u16> = vec![0, 0, 0, 1, 0, 2, 0, 0, 3, 0, 1];

        let encoded = encode_reversed(&test_seq, &dist);
        let mut decoder = AnsDecoder::new(&encoded).expect("ok");
        for expected in test_seq.iter().copied() {
            let decoded = decoder.decode_symbol(&dist).expect("ok");
            assert_eq!(decoded, expected);
        }
    }

    /// Round trip of the same symbol repeated several times.
    #[test]
    #[ignore]
    fn test_ans_roundtrip_repeated_symbol() {
        let dist = uniform_distribution(4).expect("ok");
        let symbols: Vec<u16> = vec![1, 1, 1, 1, 1];

        let encoded = encode_reversed(&symbols, &dist);
        let mut decoder = AnsDecoder::new(&encoded).expect("ok");
        for expected in symbols.iter().copied() {
            let decoded = decoder.decode_symbol(&dist).expect("ok");
            assert_eq!(decoded, expected);
        }
    }

    /// Round trip of one hundred symbols cycling through sixteen values.
    #[test]
    #[ignore]
    fn test_ans_roundtrip_long_sequence() {
        let dist = uniform_distribution(16).expect("ok");
        let symbols: Vec<u16> = (0..100).map(|i| (i % 16) as u16).collect();

        let encoded = encode_reversed(&symbols, &dist);
        let mut decoder = AnsDecoder::new(&encoded).expect("ok");
        for (i, expected) in symbols.iter().copied().enumerate() {
            let decoded = decoder.decode_symbol(&dist).expect("ok");
            assert_eq!(decoded, expected, "Mismatch at position {i}");
        }
    }

    /// A distribution without symbols is rejected.
    #[test]
    #[ignore]
    fn test_empty_distribution_error() {
        let result = AnsDistribution::new(vec![], vec![], 10);
        assert!(result.is_err());
    }

    /// A uniform distribution over zero symbols is rejected.
    #[test]
    #[ignore]
    fn test_zero_symbol_uniform_error() {
        let result = uniform_distribution(0);
        assert!(result.is_err());
    }
}