Skip to main content

jxl_encoder/entropy_coding/
ans_decode.rs

1// Copyright (c) Imazen LLC and the JPEG XL Project Authors.
2// Algorithms and constants derived from libjxl (BSD-3-Clause).
3// Licensed under AGPL-3.0-or-later. Commercial licenses at https://www.imazen.io/pricing
4//
5// ANS decoder ported from jxl-rs for testing purposes.
6
7use crate::error::{Error, Result};
8
/// Log2 of the total probability mass of every ANS distribution.
const LOG_SUM_PROBS: usize = 12;
/// Total probability mass a decoded distribution must sum to (`1 << LOG_SUM_PROBS`).
const SUM_PROBS: u16 = 1u16 << LOG_SUM_PROBS;
/// Prefix-code symbol that introduces a run-length repeat in the complex
/// histogram encoding (one past `LOG_SUM_PROBS`).
const RLE_MARKER_SYM: u16 = (LOG_SUM_PROBS + 1) as u16;
12
13/// Simple bit reader for testing.
14pub struct BitReader<'a> {
15    data: &'a [u8],
16    bit_pos: usize,
17}
18
19impl<'a> BitReader<'a> {
20    pub fn new(data: &'a [u8]) -> Self {
21        Self { data, bit_pos: 0 }
22    }
23
24    pub fn read(&mut self, n: usize) -> Result<u64> {
25        let mut val = 0u64;
26        for i in 0..n {
27            let byte_idx = self.bit_pos / 8;
28            let bit_idx = self.bit_pos % 8;
29            if byte_idx >= self.data.len() {
30                return Err(Error::Bitstream("unexpected end of data".to_string()));
31            }
32            let bit = ((self.data[byte_idx] >> bit_idx) & 1) as u64;
33            val |= bit << i;
34            self.bit_pos += 1;
35        }
36        Ok(val)
37    }
38
39    pub fn peek(&mut self, n: usize) -> u64 {
40        let old_pos = self.bit_pos;
41        let val = self.read(n).unwrap_or(0);
42        self.bit_pos = old_pos;
43        val
44    }
45
46    pub fn consume(&mut self, n: usize) -> Result<()> {
47        self.bit_pos += n;
48        if self.bit_pos > self.data.len() * 8 {
49            return Err(Error::Bitstream("unexpected end of data".to_string()));
50        }
51        Ok(())
52    }
53
54    pub fn bits_read(&self) -> usize {
55        self.bit_pos
56    }
57}
58
59/// Decoded ANS histogram for decoding.
60#[derive(Debug)]
61pub struct AnsHistogram {
62    pub buckets: Vec<Bucket>,
63    pub log_bucket_size: usize,
64    pub bucket_mask: u32,
65    pub single_symbol: Option<u32>,
66    pub frequencies: Vec<u16>,
67}
68
/// One alias-table bucket: a run of state slots shared by the bucket's own
/// symbol (below `alias_cutoff`) and an alias symbol (at or above it).
#[derive(Clone, Copy, Debug)]
pub struct Bucket {
    /// Symbol decoded for slot positions at or past `alias_cutoff`.
    pub alias_symbol: u8,
    /// Slot positions below this value decode to the bucket's own index.
    pub alias_cutoff: u8,
    /// Frequency of the bucket's own symbol.
    pub dist: u16,
    /// Added to the slot position when the alias symbol is selected.
    pub alias_offset: u16,
    /// `dist` XOR the alias symbol's frequency, so `dist ^ alias_dist_xor`
    /// yields the alias frequency.
    pub alias_dist_xor: u16,
}
77
78impl AnsHistogram {
79    pub fn decode(br: &mut BitReader, log_alpha_size: usize) -> Result<Self> {
80        debug_assert!((5..=8).contains(&log_alpha_size));
81        let table_size = 1usize << log_alpha_size;
82        let log_bucket_size = LOG_SUM_PROBS - log_alpha_size;
83        let bucket_size = 1u16 << log_bucket_size;
84        let bucket_mask = bucket_size as u32 - 1;
85
86        let mut dist = vec![0u16; table_size];
87        let alphabet_size = if br.read(1)? != 0 {
88            if br.read(1)? != 0 {
89                Self::decode_dist_two_symbols(br, &mut dist)?
90            } else {
91                Self::decode_dist_single_symbol(br, &mut dist)?
92            }
93        } else if br.read(1)? != 0 {
94            Self::decode_dist_evenly_distributed(br, &mut dist)?
95        } else {
96            Self::decode_dist_complex(br, &mut dist)?
97        };
98
99        let frequencies = dist.clone();
100
101        if let Some(single_sym_idx) = dist.iter().position(|&d| d == SUM_PROBS) {
102            let buckets = dist
103                .into_iter()
104                .enumerate()
105                .map(|(i, dist)| Bucket {
106                    dist,
107                    alias_symbol: single_sym_idx as u8,
108                    alias_offset: bucket_size * i as u16,
109                    alias_cutoff: 0,
110                    alias_dist_xor: dist ^ SUM_PROBS,
111                })
112                .collect();
113            return Ok(Self {
114                buckets,
115                log_bucket_size,
116                bucket_mask,
117                single_symbol: Some(single_sym_idx as u32),
118                frequencies,
119            });
120        }
121
122        Ok(Self {
123            buckets: Self::build_alias_map(alphabet_size, log_bucket_size, &dist),
124            log_bucket_size,
125            bucket_mask,
126            single_symbol: None,
127            frequencies,
128        })
129    }
130
131    fn decode_dist_two_symbols(br: &mut BitReader, dist: &mut [u16]) -> Result<usize> {
132        let table_size = dist.len();
133
134        let v0 = Self::read_u8(br)? as usize;
135        let v1 = Self::read_u8(br)? as usize;
136        if v0 == v1 {
137            return Err(Error::InvalidHistogram(
138                "two symbols are the same".to_string(),
139            ));
140        }
141
142        let alphabet_size = v0.max(v1) + 1;
143        if alphabet_size > table_size {
144            return Err(Error::InvalidHistogram("alphabet too large".to_string()));
145        }
146
147        let prob = br.read(LOG_SUM_PROBS)? as u16;
148        dist[v0] = prob;
149        dist[v1] = SUM_PROBS - prob;
150
151        Ok(alphabet_size)
152    }
153
154    fn decode_dist_single_symbol(br: &mut BitReader, dist: &mut [u16]) -> Result<usize> {
155        let table_size = dist.len();
156
157        let val = Self::read_u8(br)? as usize;
158        let alphabet_size = val + 1;
159        if alphabet_size > table_size {
160            return Err(Error::InvalidHistogram("alphabet too large".to_string()));
161        }
162
163        dist[val] = SUM_PROBS;
164
165        Ok(alphabet_size)
166    }
167
168    fn decode_dist_evenly_distributed(br: &mut BitReader, dist: &mut [u16]) -> Result<usize> {
169        let table_size = dist.len();
170
171        let alphabet_size = Self::read_u8(br)? as usize + 1;
172        if alphabet_size > table_size {
173            return Err(Error::InvalidHistogram("alphabet too large".to_string()));
174        }
175
176        let base = SUM_PROBS as usize / alphabet_size;
177        let remainder = SUM_PROBS as usize % alphabet_size;
178        dist[0..remainder].fill(base as u16 + 1);
179        dist[remainder..alphabet_size].fill(base as u16);
180
181        Ok(alphabet_size)
182    }
183
184    fn decode_dist_complex(br: &mut BitReader, dist: &mut [u16]) -> Result<usize> {
185        let table_size = dist.len();
186
187        let mut len = 0usize;
188        while len < 3 {
189            if br.read(1)? != 0 {
190                len += 1;
191            } else {
192                break;
193            }
194        }
195
196        let shift = (br.read(len)? + (1 << len) - 1) as i16;
197        if shift > 13 {
198            return Err(Error::InvalidHistogram("shift too large".to_string()));
199        }
200
201        let alphabet_size = Self::read_u8(br)? as usize + 3;
202        if alphabet_size > table_size {
203            return Err(Error::InvalidHistogram("alphabet too large".to_string()));
204        }
205
206        let mut repeat_ranges = Vec::new();
207        let mut omit_data: Option<(u16, usize)> = None;
208        let mut idx = 0;
209        while idx < alphabet_size {
210            dist[idx] = Self::read_prefix(br)?;
211            if dist[idx] == RLE_MARKER_SYM {
212                let repeat_count = Self::read_u8(br)? as usize + 4;
213                if idx + repeat_count > alphabet_size {
214                    return Err(Error::InvalidHistogram("RLE overflow".to_string()));
215                }
216                repeat_ranges.push(idx..(idx + repeat_count));
217                idx += repeat_count;
218                continue;
219            }
220            match &mut omit_data {
221                Some((log, pos)) => {
222                    if dist[idx] > *log {
223                        *log = dist[idx];
224                        *pos = idx;
225                    }
226                }
227                data => {
228                    *data = Some((dist[idx], idx));
229                }
230            }
231            idx += 1;
232        }
233        let Some((_, omit_pos)) = omit_data else {
234            return Err(Error::InvalidHistogram("no omit position".to_string()));
235        };
236        if dist.get(omit_pos + 1) == Some(&RLE_MARKER_SYM) {
237            return Err(Error::InvalidHistogram("RLE after omit".to_string()));
238        }
239
240        let mut repeat_range_idx = 0usize;
241        let mut acc = 0;
242        let mut prev_dist = 0u16;
243        for (idx, code) in dist.iter_mut().enumerate() {
244            if repeat_range_idx < repeat_ranges.len()
245                && repeat_ranges[repeat_range_idx].start <= idx
246            {
247                if repeat_ranges[repeat_range_idx].end == idx {
248                    repeat_range_idx += 1;
249                } else {
250                    *code = prev_dist;
251                    acc += *code;
252                    if acc >= SUM_PROBS {
253                        return Err(Error::InvalidHistogram("sum overflow".to_string()));
254                    }
255                    continue;
256                }
257            }
258
259            if *code == 0 {
260                prev_dist = 0;
261                continue;
262            }
263            if idx == omit_pos {
264                prev_dist = 0;
265                continue;
266            }
267            if *code > 1 {
268                let zeros = (*code - 1) as i16;
269                let bitcount = (shift - ((LOG_SUM_PROBS as i16 - zeros) >> 1)).clamp(0, zeros);
270                *code = (1 << zeros) + ((br.read(bitcount as usize)? as u16) << (zeros - bitcount));
271            }
272
273            prev_dist = *code;
274            acc += *code;
275            if acc >= SUM_PROBS {
276                return Err(Error::InvalidHistogram("sum overflow".to_string()));
277            }
278        }
279        dist[omit_pos] = SUM_PROBS - acc;
280
281        Ok(alphabet_size)
282    }
283
284    /// Public alias map builder for verification/testing.
285    pub fn build_alias_map_from_freqs(
286        alphabet_size: usize,
287        log_bucket_size: usize,
288        dist: &[u16],
289    ) -> Vec<Bucket> {
290        Self::build_alias_map(alphabet_size, log_bucket_size, dist)
291    }
292
293    fn build_alias_map(alphabet_size: usize, log_bucket_size: usize, dist: &[u16]) -> Vec<Bucket> {
294        struct WorkingBucket {
295            dist: u16,
296            alias_symbol: u16,
297            alias_offset: u16,
298            alias_cutoff: u16,
299        }
300
301        let bucket_size = 1u16 << log_bucket_size;
302        let mut buckets: Vec<_> = dist
303            .iter()
304            .enumerate()
305            .map(|(i, &dist)| WorkingBucket {
306                dist,
307                alias_symbol: if i < alphabet_size { i as u16 } else { 0 },
308                alias_offset: 0,
309                alias_cutoff: dist,
310            })
311            .collect();
312
313        let mut underfull = Vec::new();
314        let mut overfull = Vec::new();
315        for (idx, bucket) in buckets.iter().enumerate() {
316            match bucket.dist.cmp(&bucket_size) {
317                std::cmp::Ordering::Less => underfull.push(idx),
318                std::cmp::Ordering::Equal => {}
319                std::cmp::Ordering::Greater => overfull.push(idx),
320            }
321        }
322        while let (Some(o), Some(u)) = (overfull.pop(), underfull.pop()) {
323            let by = bucket_size - buckets[u].alias_cutoff;
324            buckets[o].alias_cutoff -= by;
325            buckets[u].alias_symbol = o as u16;
326            buckets[u].alias_offset = buckets[o].alias_cutoff;
327            match buckets[o].alias_cutoff.cmp(&bucket_size) {
328                std::cmp::Ordering::Less => underfull.push(o),
329                std::cmp::Ordering::Equal => {}
330                std::cmp::Ordering::Greater => overfull.push(o),
331            }
332        }
333
334        buckets
335            .iter()
336            .enumerate()
337            .map(|(idx, bucket)| {
338                if bucket.alias_cutoff == bucket_size {
339                    Bucket {
340                        dist: bucket.dist,
341                        alias_symbol: idx as u8,
342                        alias_offset: 0,
343                        alias_cutoff: 0,
344                        alias_dist_xor: 0,
345                    }
346                } else {
347                    Bucket {
348                        dist: bucket.dist,
349                        alias_symbol: bucket.alias_symbol as u8,
350                        alias_offset: bucket.alias_offset - bucket.alias_cutoff,
351                        alias_cutoff: bucket.alias_cutoff as u8,
352                        alias_dist_xor: bucket.dist ^ buckets[bucket.alias_symbol as usize].dist,
353                    }
354                }
355            })
356            .collect()
357    }
358
359    fn read_u8(br: &mut BitReader) -> Result<u8> {
360        Ok(if br.read(1)? != 0 {
361            let n = br.read(3)?;
362            ((1 << n) + br.read(n as usize)?) as u8
363        } else {
364            0
365        })
366    }
367
368    fn read_prefix(br: &mut BitReader) -> Result<u16> {
369        #[rustfmt::skip]
370        const TABLE: [(u8, u8); 128] = [
371            (10, 3), (12, 7), (7, 3), (3, 4), (6, 3), (8, 3), (9, 3), (5, 4),
372            (10, 3), ( 4, 4), (7, 3), (1, 4), (6, 3), (8, 3), (9, 3), (2, 4),
373            (10, 3), ( 0, 5), (7, 3), (3, 4), (6, 3), (8, 3), (9, 3), (5, 4),
374            (10, 3), ( 4, 4), (7, 3), (1, 4), (6, 3), (8, 3), (9, 3), (2, 4),
375            (10, 3), (11, 6), (7, 3), (3, 4), (6, 3), (8, 3), (9, 3), (5, 4),
376            (10, 3), ( 4, 4), (7, 3), (1, 4), (6, 3), (8, 3), (9, 3), (2, 4),
377            (10, 3), ( 0, 5), (7, 3), (3, 4), (6, 3), (8, 3), (9, 3), (5, 4),
378            (10, 3), ( 4, 4), (7, 3), (1, 4), (6, 3), (8, 3), (9, 3), (2, 4),
379            (10, 3), (13, 7), (7, 3), (3, 4), (6, 3), (8, 3), (9, 3), (5, 4),
380            (10, 3), ( 4, 4), (7, 3), (1, 4), (6, 3), (8, 3), (9, 3), (2, 4),
381            (10, 3), ( 0, 5), (7, 3), (3, 4), (6, 3), (8, 3), (9, 3), (5, 4),
382            (10, 3), ( 4, 4), (7, 3), (1, 4), (6, 3), (8, 3), (9, 3), (2, 4),
383            (10, 3), (11, 6), (7, 3), (3, 4), (6, 3), (8, 3), (9, 3), (5, 4),
384            (10, 3), ( 4, 4), (7, 3), (1, 4), (6, 3), (8, 3), (9, 3), (2, 4),
385            (10, 3), ( 0, 5), (7, 3), (3, 4), (6, 3), (8, 3), (9, 3), (5, 4),
386            (10, 3), ( 4, 4), (7, 3), (1, 4), (6, 3), (8, 3), (9, 3), (2, 4),
387        ];
388
389        let index = br.peek(7) as usize;
390        let (sym, bits) = TABLE[index];
391        br.consume(bits as usize)?;
392        Ok(sym as u16)
393    }
394
395    /// Decode a symbol and update state.
396    pub fn read(&self, br: &mut BitReader, state: &mut u32) -> u32 {
397        let idx = *state & 0xfff;
398        let i = (idx >> self.log_bucket_size) as usize;
399        let pos = idx & self.bucket_mask;
400
401        let bucket = &self.buckets[i & (self.buckets.len() - 1)];
402        let alias_symbol = bucket.alias_symbol as u32;
403        let alias_cutoff = bucket.alias_cutoff as u32;
404        let dist = bucket.dist as u32;
405
406        let map_to_alias = (pos >= alias_cutoff) as u32;
407        let offset = (bucket.alias_offset as u32) * map_to_alias;
408        let dist_xor = (bucket.alias_dist_xor as u32) * map_to_alias;
409
410        let dist = dist ^ dist_xor;
411        let symbol = (alias_symbol * map_to_alias) | (i as u32 * (1 - map_to_alias));
412        let offset = offset + pos;
413
414        let next_state = (*state >> LOG_SUM_PROBS) * dist + offset;
415        let select_appended = (next_state < (1 << 16)) as u32;
416        let appended_bits = br.peek(16) as u32;
417        let appended_state = (next_state << 16) | appended_bits;
418        *state = (appended_state * select_appended) | (next_state * (1 - select_appended));
419        if select_appended != 0 {
420            br.consume(16).ok();
421        }
422        symbol
423    }
424}
425
426/// ANS state reader.
427pub struct AnsReader(pub u32);
428
429impl AnsReader {
430    pub const CHECKSUM: u32 = 0x130000;
431
432    pub fn init(br: &mut BitReader) -> Result<Self> {
433        let initial_state = br.read(32)? as u32;
434        Ok(Self(initial_state))
435    }
436
437    pub fn check_final_state(&self) -> Result<()> {
438        if self.0 == Self::CHECKSUM {
439            Ok(())
440        } else {
441            Err(Error::Bitstream(format!(
442                "ANS checksum mismatch: got 0x{:08x}, expected 0x{:08x}",
443                self.0,
444                Self::CHECKSUM
445            )))
446        }
447    }
448
449    pub fn state(&self) -> u32 {
450        self.0
451    }
452}
453
// Round-trip tests: histograms written by the encoder's `ANSEncodingHistogram`
// must decode back to the same frequencies with `AnsHistogram::decode`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::bit_writer::BitWriter;
    use crate::entropy_coding::ans::{ANSEncodingHistogram, ANSHistogramStrategy};
    use crate::entropy_coding::histogram::Histogram;

    // A histogram with a single non-zero symbol must decode via the
    // single-symbol path, with that symbol holding all 4096 units.
    #[test]
    fn test_decode_single_symbol() {
        // Create and write a single-symbol histogram
        let histo = Histogram::from_counts(&[100, 0, 0, 0]);
        let ans_histo =
            ANSEncodingHistogram::from_histogram(&histo, ANSHistogramStrategy::Precise).unwrap();

        let mut writer = BitWriter::new();
        ans_histo.write(&mut writer).unwrap();
        let bytes = writer.finish_with_padding();

        println!("Single symbol histogram bytes: {:02x?}", bytes);

        // Decode it back (log_alpha_size 6 => 64-entry alias table)
        let mut br = BitReader::new(&bytes);
        let decoded = AnsHistogram::decode(&mut br, 6).unwrap();

        println!("Decoded frequencies: {:?}", &decoded.frequencies[..4]);
        println!("Single symbol: {:?}", decoded.single_symbol);

        // Verify
        assert_eq!(decoded.single_symbol, Some(0));
        assert_eq!(decoded.frequencies[0], 4096);
    }

    // Two equally likely symbols: the decoded frequencies must sum to 4096
    // and match the counts the encoder chose.
    #[test]
    fn test_decode_two_symbols() {
        // Create and write a two-symbol histogram
        let histo = Histogram::from_counts(&[100, 100, 0, 0]);
        let ans_histo =
            ANSEncodingHistogram::from_histogram(&histo, ANSHistogramStrategy::Precise).unwrap();

        println!("Two symbol histogram: {:?}", ans_histo.counts);

        let mut writer = BitWriter::new();
        ans_histo.write(&mut writer).unwrap();
        let bytes = writer.finish_with_padding();

        println!("Two symbol histogram bytes: {:02x?}", bytes);

        // Decode it back
        let mut br = BitReader::new(&bytes);
        let decoded = AnsHistogram::decode(&mut br, 6).unwrap();

        println!("Decoded frequencies: {:?}", &decoded.frequencies[..4]);

        // Verify sum
        let sum: u16 = decoded.frequencies.iter().sum();
        assert_eq!(sum, 4096, "Sum should be 4096");

        // Verify the two non-zero entries match what we wrote
        assert_eq!(decoded.frequencies[0], ans_histo.counts[0] as u16);
        assert_eq!(decoded.frequencies[1], ans_histo.counts[1] as u16);
    }

    // A skewed four-symbol histogram; per-symbol frequencies must match the
    // encoder's counts exactly.
    #[test]
    fn test_decode_general_histogram() {
        // Create and write a general histogram
        let histo = Histogram::from_counts(&[100, 50, 25, 10]);
        let ans_histo =
            ANSEncodingHistogram::from_histogram(&histo, ANSHistogramStrategy::Precise).unwrap();

        println!("General histogram:");
        println!("  counts: {:?}", ans_histo.counts);
        println!(
            "  method: {}, alphabet_size: {}, omit_pos: {}",
            ans_histo.method, ans_histo.alphabet_size, ans_histo.omit_pos
        );

        let mut writer = BitWriter::new();
        ans_histo.write(&mut writer).unwrap();
        let bytes = writer.finish_with_padding();

        println!("  bytes ({} bytes): {:02x?}", bytes.len(), bytes);

        // Decode it back
        let mut br = BitReader::new(&bytes);
        let decoded = AnsHistogram::decode(&mut br, 6).unwrap();

        println!(
            "Decoded frequencies: {:?}",
            &decoded.frequencies[..ans_histo.alphabet_size]
        );

        // Verify sum
        let sum: u16 = decoded.frequencies.iter().sum();
        assert_eq!(sum, 4096, "Sum should be 4096");

        // Verify frequencies match what we wrote
        for i in 0..ans_histo.alphabet_size {
            assert_eq!(
                decoded.frequencies[i], ans_histo.counts[i] as u16,
                "Frequency mismatch at symbol {}",
                i
            );
        }
    }

    // Regression test: a sparse histogram (one dominant symbol, two rare ones
    // far away) that previously failed to round-trip.
    #[test]
    fn test_decode_sparse_histogram_roundtrip() {
        // Reproduce the exact histogram that fails:
        // alphabet_size=36, symbols at positions 1 (4092), 31 (2), 35 (2)
        // This is the histogram that caused gradient_256 tree learning to fail.
        let mut raw_counts = vec![0i32; 40]; // padded to HISTOGRAM_ROUNDING
        raw_counts[1] = 196000; // dominant symbol
        raw_counts[31] = 100; // rare symbol
        raw_counts[35] = 100; // rare symbol
        let histo = Histogram::from_counts(&raw_counts);

        let ans_histo =
            ANSEncodingHistogram::from_histogram(&histo, ANSHistogramStrategy::Precise).unwrap();

        println!("Sparse histogram:");
        println!(
            "  method={}, alphabet_size={}, omit_pos={}",
            ans_histo.method, ans_histo.alphabet_size, ans_histo.omit_pos
        );
        println!("  non-zero counts:");
        for (i, &c) in ans_histo.counts.iter().enumerate() {
            if c != 0 {
                println!("    [{}] = {}", i, c);
            }
        }

        let mut writer = BitWriter::new();
        ans_histo.write(&mut writer).unwrap();
        // Add padding so decoder's peek(7) doesn't read past end
        writer.write(8, 0).unwrap();
        writer.zero_pad_to_byte();
        let bytes = writer.finish();

        println!(
            "  encoded bytes ({} bytes): {:02x?}",
            bytes.len(),
            &bytes[..bytes.len().min(32)]
        );

        // Decode it back
        let mut br = BitReader::new(&bytes);
        let decoded = AnsHistogram::decode(&mut br, 6).unwrap();

        println!("  decoded frequencies:");
        for (i, &f) in decoded.frequencies.iter().enumerate() {
            if f != 0 {
                println!("    [{}] = {}", i, f);
            }
        }

        // Verify frequencies match
        let sum: u16 = decoded.frequencies.iter().sum();
        assert_eq!(sum, 4096, "Sum should be 4096 but got {}", sum);

        for i in 0..ans_histo.alphabet_size {
            assert_eq!(
                decoded.frequencies[i], ans_histo.counts[i] as u16,
                "Frequency mismatch at symbol {}: encoder wrote {}, decoder read {}",
                i, ans_histo.counts[i], decoded.frequencies[i]
            );
        }
    }
}