Skip to main content

jxl_encoder/entropy_coding/
context_map.rs

1// Copyright (c) Imazen LLC and the JPEG XL Project Authors.
2// Algorithms and constants derived from libjxl (BSD-3-Clause).
3// Licensed under AGPL-3.0-or-later. Commercial licenses at https://www.imazen.io/pricing
4
5//! Context map encoding.
6//!
7//! Ported from libjxl `lib/jxl/enc_context_map.cc`.
8//!
9//! # Current Implementation
10//!
11//! Currently uses simple encoding only, which supports up to 8 histograms
12//! (3 bits per entry). This covers most practical use cases.
13//!
14//! # Future Enhancement: ANS Encoding
15//!
16//! For better compression with many clusters (>8), the JXL format supports
17//! ANS-based context map encoding with optional MTF transform:
18//! 1. **Prefix code (Huffman)**: Uses Huffman codes for symbols.
19//! 2. **Prefix code with MTF**: Applies move-to-front transform before Huffman.
20//!
21//! This requires implementing the full JXL entropy bundle format, which is
22//! non-trivial. See libjxl `lib/jxl/enc_context_map.cc` for reference.
23
24use crate::bit_writer::BitWriter;
25use crate::error::Result;
26
/// Applies the move-to-front (MTF) transform to `input`.
///
/// A working list is seeded with every symbol from `0` up to the largest
/// value found in `input`, in ascending order. Each input symbol is replaced
/// by its current position in that list, and the symbol is then moved to
/// position 0. Symbols that recur at short distances therefore map to small
/// indices.
///
/// Returns one index per input symbol; an empty input yields an empty vector.
pub fn move_to_front_transform(input: &[u8]) -> Vec<u8> {
    // An empty input has no maximum, which doubles as the empty-input guard.
    let largest = match input.iter().max() {
        Some(&m) => m,
        None => return Vec::new(),
    };

    let mut recency: Vec<u8> = (0..=largest).collect();

    input
        .iter()
        .map(|&symbol| {
            // The list holds every value in 0..=largest, so the lookup always succeeds.
            let pos = recency
                .iter()
                .position(|&candidate| candidate == symbol)
                .unwrap();

            // Rotating the prefix by one brings `symbol` to the front and
            // shifts everything that was ahead of it back by one slot.
            recency[..=pos].rotate_right(1);

            // The list has at most 256 entries, so `pos` always fits in a u8.
            pos as u8
        })
        .collect()
}
54
/// Reverses [`move_to_front_transform`] (intended for tests).
///
/// The working list starts as `0..=max_symbol` in ascending order. Each entry
/// of `input` is treated as a position in that list: the symbol found there is
/// emitted and then moved to the front.
///
/// An index that lies outside the list does not stop decoding: a `0` is
/// emitted in its place, the list is left untouched, and processing continues
/// with the next index.
///
/// Returns one symbol per input index; an empty input yields an empty vector.
pub fn inverse_move_to_front_transform(input: &[u8], max_symbol: u8) -> Vec<u8> {
    let mut recency: Vec<u8> = (0..=max_symbol).collect();

    input
        .iter()
        .map(|&raw| {
            let pos = usize::from(raw);
            match recency.get(pos).copied() {
                Some(symbol) => {
                    // Bring the decoded symbol to the front, shifting the
                    // entries that preceded it back by one slot.
                    recency[..=pos].rotate_right(1);
                    symbol
                }
                // Out-of-range index: emit a placeholder and keep going.
                None => 0,
            }
        })
        .collect()
}
84
85/// Encode context map to bitstream.
86///
87/// The context map maps context indices to histogram indices.
88///
89/// # Current Implementation
90///
91/// Uses simple encoding only, which supports up to 8 histograms (3 bits per entry).
92/// This is sufficient for reasonable clustering without the complexity of implementing
93/// the full ANS entropy bundle format for context maps.
94///
95/// # Encoding Format
96///
97/// - If num_histograms == 1: write (1, 0) → no actual entries needed
98/// - Simple mode: write (1, entry_bits) then each entry with entry_bits bits
99///
100/// # Future Work
101///
102/// To support >8 histograms with efficient encoding, implement the full JXL entropy
103/// bundle format for context maps (is_simple=0 path). This requires:
104/// - lz77.enabled flag
105/// - Full histogram bundle with ANS/prefix codes
106/// - HybridUint encoding for symbols
107///
108/// Reference: libjxl lib/jxl/enc_context_map.cc
109pub fn encode_context_map(
110    context_map: &[u8],
111    num_histograms: usize,
112    writer: &mut BitWriter,
113) -> Result<()> {
114    if num_histograms == 1 {
115        // Simple code: all contexts map to histogram 0
116        writer.write(1, 1)?; // simple flag
117        writer.write(2, 0)?; // 0 bits per entry
118        return Ok(());
119    }
120
121    // Calculate bits needed for simple encoding
122    // Simple mode supports bits_per_entry = 0, 1, 2, or 3 (encoded in 2 bits)
123    // This allows up to 8 histograms (2^3 = 8)
124    let entry_bits = ceil_log2_nonzero(num_histograms);
125
126    if entry_bits > 3 {
127        // Simple mode only supports up to 3 bits per entry (8 clusters)
128        // For now, just use 3 bits and mask values (clustering should ensure <= 8 clusters)
129        crate::trace::debug_eprintln!(
130            "WARNING: context_map requires {} bits but simple mode max is 3 bits. \
131             Using 3 bits, which may cause decoding errors if num_histograms > 8.",
132            entry_bits
133        );
134    }
135
136    let effective_bits = entry_bits.min(3);
137    writer.write(1, 1)?; // simple flag
138    writer.write(2, effective_bits as u64)?; // bits per entry
139    for &entry in context_map {
140        // Mask entry to fit within effective_bits
141        let masked_entry = entry & ((1 << effective_bits) - 1);
142        writer.write(effective_bits, masked_entry as u64)?;
143    }
144
145    Ok(())
146}
147
/// Returns `ceil(log2(n))`, i.e. the smallest `b` such that `2^b >= n`.
///
/// Intended for non-zero `n`; both `0` and `1` yield `0`.
#[inline]
fn ceil_log2_nonzero(n: usize) -> usize {
    match n {
        0 | 1 => 0,
        _ => {
            // The bit length of `n - 1` equals the ceiling of log2(n) for n >= 2.
            let below = n - 1;
            (usize::BITS - below.leading_zeros()) as usize
        }
    }
}
157
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the forward transform followed by the inverse (seeded with the
    /// largest symbol of `original`) and checks the sequence is recovered.
    fn assert_mtf_roundtrip(original: &[u8]) {
        let max_symbol = original.iter().copied().max().unwrap();
        let encoded = move_to_front_transform(original);
        let decoded = inverse_move_to_front_transform(&encoded, max_symbol);
        assert_eq!(decoded, original);
    }

    #[test]
    fn test_mtf_simple() {
        // The list starts as [0, 1, 2]; each step shows the list afterwards:
        //   1 -> 1, [1, 0, 2]
        //   2 -> 2, [2, 1, 0]
        //   1 -> 1, [1, 2, 0]
        //   2 -> 1, [2, 1, 0]
        //   1 -> 1, [1, 2, 0]
        //   2 -> 1, [2, 1, 0]
        let symbols = [1u8, 2, 1, 2, 1, 2];
        assert_eq!(move_to_front_transform(&symbols), [1u8, 2, 1, 1, 1, 1]);
    }

    #[test]
    fn test_mtf_repeated() {
        // The first 5 sits at index 5; once moved to the front, every repeat
        // is found at index 0.
        let symbols = [5u8; 4];
        assert_eq!(move_to_front_transform(&symbols), [5u8, 0, 0, 0]);
    }

    #[test]
    fn test_mtf_empty() {
        let symbols: [u8; 0] = [];
        assert!(move_to_front_transform(&symbols).is_empty());
    }

    #[test]
    fn test_mtf_roundtrip() {
        assert_mtf_roundtrip(&[3, 1, 4, 1, 5, 9, 2, 6, 5, 3]);
    }

    #[test]
    fn test_mtf_roundtrip_sequential() {
        let ascending: Vec<u8> = (0..10).collect();
        assert_mtf_roundtrip(&ascending);
    }

    #[test]
    fn test_encode_context_map_single() {
        let context_map = [0u8; 4];
        let mut writer = BitWriter::new();

        encode_context_map(&context_map, 1, &mut writer).unwrap();

        // Flag bit plus a two-bit zero width: three bits before padding.
        let bytes = writer.finish_with_padding();
        assert!(!bytes.is_empty());
    }

    #[test]
    fn test_encode_context_map_two_histograms() {
        let context_map = [0u8, 1, 0, 1];
        let mut writer = BitWriter::new();

        encode_context_map(&context_map, 2, &mut writer).unwrap();

        // Two histograms need one bit per entry.
        let bytes = writer.finish_with_padding();
        assert!(!bytes.is_empty());
    }

    #[test]
    fn test_ceil_log2() {
        let expectations: [(usize, usize); 8] = [
            (1, 0),
            (2, 1),
            (3, 2),
            (4, 2),
            (5, 3),
            (8, 3),
            (9, 4),
            (256, 8),
        ];
        for (n, expected) in expectations.iter().copied() {
            assert_eq!(ceil_log2_nonzero(n), expected);
        }
    }
}