jxl_coding/
prefix.rs

1//! Prefix code based on Brotli
2use jxl_bitstream::Bitstream;
3
4use crate::{Error, Result};
5
/// Maximum length, in bits, of a single prefix code.
///
/// `Histogram::read_symbol` peeks this many bits at once, and `Histogram::parse` rejects
/// alphabets with more than `1 << MAX_PREFIX_BITS` symbols.
const MAX_PREFIX_BITS: usize = 15;
/// Upper bound on the number of bits used to index the first-level lookup table.
///
/// Codes longer than this are resolved through the second-level table.
const MAX_TOPLEVEL_BITS: usize = 10;
8
/// Two-level lookup table used to decode symbols of a prefix code.
///
/// The low `toplevel_bits` bits of the upcoming input select an entry of `toplevel_entries`.
/// That entry is either a leaf (symbol plus code length) or a reference to a chunk of
/// `second_level_entries`, which is indexed by the bits that follow.
#[derive(Debug)]
pub struct Histogram {
    /// Number of input bits used to index `toplevel_entries` (0 for a single-symbol code).
    toplevel_bits: usize,
    /// Bit mask selecting the low `toplevel_bits` bits, i.e. `(1 << toplevel_bits) - 1`.
    toplevel_mask: u32,
    /// First-level table, indexed by the masked peeked bits.
    toplevel_entries: Vec<Entry>,
    /// Concatenated second-level chunks, addressed through nested first-level entries.
    second_level_entries: Vec<Entry>,
}
16
/// Single slot of a lookup table; either a leaf or a link to a second-level chunk.
#[derive(Debug, Copy, Clone, Default)]
struct Entry {
    /// `true` if this first-level entry refers to a chunk of the second-level table,
    /// `false` if it is a leaf that directly yields a symbol.
    nested: bool,
    /// Leaf: number of bits to consume for the decoded code.
    /// Nested: mask applied to the bits following the first-level index to locate the entry
    /// within the chunk.
    bits_or_mask: u8,
    /// Leaf: the decoded symbol.
    /// Nested: offset of the chunk's first entry within the second-level table.
    symbol_or_offset: u16,
}
23
24impl Histogram {
    /// Builds the two-level lookup table from per-symbol code lengths.
    ///
    /// `code_lengths[sym]` is the code length of symbol `sym`; a length of zero marks the symbol
    /// as unused. Codes are laid out in order of increasing length, and in increasing symbol
    /// order within the same length.
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidPrefixHistogram` if the codes do not fill the code space exactly
    /// (including the case where every length is zero).
    ///
    /// NOTE(review): an over-subscribed set of lengths would slice or index past the end of
    /// `entries` and panic instead of returning an error; callers are expected to rule that out
    /// before calling.
    fn with_code_lengths(code_lengths: Vec<u8>) -> Result<Self> {
        // Bucket symbols by code length: `syms_for_length[len - 1]` lists, in ascending order,
        // the symbols whose code is `len` bits long.
        let mut syms_for_length = Vec::with_capacity(MAX_PREFIX_BITS);
        for (sym, len) in code_lengths.into_iter().enumerate() {
            let sym = sym as u16;
            if len > 0 {
                if syms_for_length.len() < len as usize {
                    syms_for_length.resize_with(len as usize, Vec::new);
                }
                syms_for_length[len as usize - 1].push(sym);
            }
        }

        // The first-level table is indexed by the longest code length, capped at
        // `MAX_TOPLEVEL_BITS`.
        let toplevel_bits = syms_for_length.len().min(MAX_TOPLEVEL_BITS);
        let mut entries = vec![Entry::default(); 1 << toplevel_bits];
        // Next unassigned position in `entries`, counted in MSB-first code order.
        let mut current_bits = 0u16;
        for (idx, syms) in syms_for_length.iter().enumerate().take(toplevel_bits) {
            // A code of length `idx + 1` covers `1 << shifts` consecutive first-level slots.
            let shifts = toplevel_bits - 1 - idx;
            for &sym in syms {
                let entry = Entry {
                    nested: false,
                    bits_or_mask: (idx + 1) as u8,
                    symbol_or_offset: sym,
                };
                entries[current_bits as usize..][..(1 << shifts)].fill(entry);
                current_bits += 1u16 << shifts;
            }
        }

        let mut second_level_entries = Vec::new();
        if toplevel_bits < syms_for_length.len() {
            // Entries of a chunk that is not full yet, carried over to the next longer length,
            // together with the number of index bits that chunk was sized for.
            let mut remaining_entries = Vec::new();
            let mut remaining_entry_bits = 0usize;
            for (idx, syms) in syms_for_length.iter().enumerate().skip(toplevel_bits) {
                if syms.is_empty() {
                    continue;
                }

                // Chunks for this length are indexed by the bits beyond the first-level index.
                let chunk_size_bits = idx + 1 - toplevel_bits;
                let chunk_size = 1usize << chunk_size_bits;
                let mut chunk = Vec::with_capacity(chunk_size);
                if !remaining_entries.is_empty() {
                    // Widen the carried-over partial chunk to the new chunk size by replicating
                    // each of its entries.
                    let mult = 1usize << (chunk_size_bits - remaining_entry_bits);
                    for entry in remaining_entries {
                        for _ in 0..mult {
                            chunk.push(entry);
                        }
                    }
                }
                for &sym in syms {
                    // Second-level leaves store the full code length, which is the number of
                    // bits `read_symbol` consumes for them.
                    let entry = Entry {
                        nested: false,
                        bits_or_mask: (idx + 1) as u8,
                        symbol_or_offset: sym,
                    };
                    chunk.push(entry);
                    if chunk.len() == chunk_size {
                        // The chunk is complete: link it from the next free first-level slot and
                        // append it, bit-reversed, to the second-level table.
                        entries[current_bits as usize] = Entry {
                            nested: true,
                            bits_or_mask: (chunk_size - 1) as u8,
                            symbol_or_offset: second_level_entries.len() as u16,
                        };
                        vec_reverse_bits(&chunk, &mut second_level_entries);
                        current_bits += 1;
                        chunk = Vec::with_capacity(chunk_size);
                    }
                }
                remaining_entries = chunk;
                remaining_entry_bits = chunk_size_bits;
            }

            // A partially filled chunk after the longest length means the code is incomplete.
            if !remaining_entries.is_empty() {
                return Err(Error::InvalidPrefixHistogram);
            }
        }

        // Every first-level slot must have been assigned exactly once.
        if current_bits == 1 << toplevel_bits {
            // `entries` was filled in MSB-first order; permute it by bit-reversed index so that
            // `read_symbol` can index it with the masked low bits of the peeked input.
            let mut toplevel_entries = Vec::with_capacity(entries.len());
            vec_reverse_bits(&entries, &mut toplevel_entries);
            Ok(Self {
                toplevel_bits,
                toplevel_mask: (1 << toplevel_bits) - 1,
                toplevel_entries,
                second_level_entries,
            })
        } else {
            Err(Error::InvalidPrefixHistogram)
        }
    }
113
114    fn with_single_symbol(symbol: u16) -> Self {
115        let entry = Entry {
116            nested: false,
117            bits_or_mask: 0,
118            symbol_or_offset: symbol,
119        };
120        Self {
121            toplevel_bits: 0,
122            toplevel_mask: 0,
123            toplevel_entries: vec![entry],
124            second_level_entries: Vec::new(),
125        }
126    }
127
128    pub fn parse(bitstream: &mut Bitstream, alphabet_size: u32) -> Result<Self> {
129        if alphabet_size == 1 {
130            return Ok(Self::with_single_symbol(0));
131        }
132
133        if alphabet_size > 1u32 << MAX_PREFIX_BITS {
134            return Err(Error::PrefixSymbolTooLarge(alphabet_size as usize));
135        }
136
137        let hskip = bitstream.read_bits(2)?;
138        if hskip == 1 {
139            Self::parse_simple(bitstream, alphabet_size)
140        } else {
141            Self::parse_complex(bitstream, alphabet_size, hskip)
142        }
143    }
144
145    fn parse_simple(bitstream: &mut Bitstream, alphabet_size: u32) -> Result<Self> {
146        let alphabet_bits = alphabet_size.next_power_of_two().trailing_zeros() as usize;
147        let nsym = bitstream.read_bits(2)? + 1;
148        let it = match nsym {
149            1 => {
150                let sym = bitstream.read_bits(alphabet_bits)?;
151                if sym >= alphabet_size {
152                    return Err(Error::InvalidPrefixHistogram);
153                }
154                return Ok(Self::with_single_symbol(sym as u16));
155            }
156            2 => {
157                let syms = [
158                    0,
159                    0,
160                    bitstream.read_bits(alphabet_bits)? as usize,
161                    bitstream.read_bits(alphabet_bits)? as usize,
162                ];
163
164                syms.into_iter().zip([0u8, 0, 1u8, 1])
165            }
166            3 => {
167                let syms = [
168                    0,
169                    bitstream.read_bits(alphabet_bits)? as usize,
170                    bitstream.read_bits(alphabet_bits)? as usize,
171                    bitstream.read_bits(alphabet_bits)? as usize,
172                ];
173
174                syms.into_iter().zip([0u8, 1, 2, 2])
175            }
176            4 => {
177                let syms = [
178                    bitstream.read_bits(alphabet_bits)? as usize,
179                    bitstream.read_bits(alphabet_bits)? as usize,
180                    bitstream.read_bits(alphabet_bits)? as usize,
181                    bitstream.read_bits(alphabet_bits)? as usize,
182                ];
183                let tree_selector = bitstream.read_bool()?;
184
185                if tree_selector {
186                    syms.into_iter().zip([1u8, 2, 3, 3])
187                } else {
188                    syms.into_iter().zip([2u8, 2, 2, 2])
189                }
190            }
191            _ => unreachable!(),
192        };
193
194        let mut code_lengths = vec![0u8; alphabet_size as usize];
195        for (sym, len) in it {
196            if let Some(out) = code_lengths.get_mut(sym) {
197                *out = len;
198            } else {
199                return Err(Error::InvalidPrefixHistogram);
200            }
201        }
202        Self::with_code_lengths(code_lengths)
203    }
204
205    fn parse_complex(bitstream: &mut Bitstream, alphabet_size: u32, hskip: u32) -> Result<Self> {
206        const CODE_LENGTH_ORDER: [usize; 18] =
207            [1, 2, 3, 4, 0, 5, 17, 6, 16, 7, 8, 9, 10, 11, 12, 13, 14, 15];
208        let mut code_length_code_lengths = [0u8; 18];
209        let mut bitacc = 0usize;
210
211        let mut nonzero_count = 0;
212        let mut nonzero_sym = 0;
213        for idx in CODE_LENGTH_ORDER.into_iter().skip(hskip as usize) {
214            // Read single code length code.
215            let base = bitstream.read_u32(0, 4, 3, 8)? as u8;
216            let len = if base == 8 {
217                if bitstream.read_bool()? {
218                    if bitstream.read_bool()? {
219                        // 1111
220                        5
221                    } else {
222                        // 0111
223                        1
224                    }
225                } else {
226                    // 011
227                    2
228                }
229            } else {
230                base
231            };
232
233            code_length_code_lengths[idx] = len;
234            if len != 0 {
235                nonzero_count += 1;
236                nonzero_sym = idx;
237                bitacc += 32 >> len;
238
239                match bitacc.cmp(&32) {
240                    std::cmp::Ordering::Less => {}
241                    std::cmp::Ordering::Equal => break,
242                    std::cmp::Ordering::Greater => return Err(Error::InvalidPrefixHistogram),
243                }
244            }
245        }
246
247        let code_length_histogram = if nonzero_count == 1 {
248            Histogram::with_single_symbol(nonzero_sym as u16)
249        } else if bitacc != 32 {
250            return Err(Error::InvalidPrefixHistogram);
251        } else {
252            Histogram::with_code_lengths(code_length_code_lengths.to_vec())?
253        };
254
255        let mut code_lengths = vec![0u8; alphabet_size as usize];
256        let mut bitacc = 0usize;
257
258        let mut prev_sym = 8u8;
259        let mut last_nonzero_sym = 8u8;
260        let mut last_repeat_count = 0usize;
261
262        let mut repeat_count = 0usize;
263        let mut repeat_sym = 0u8;
264        for len in &mut code_lengths {
265            if repeat_count > 0 {
266                *len = repeat_sym;
267                repeat_count -= 1;
268            } else {
269                let sym = code_length_histogram.read_symbol(bitstream)? as u8;
270                match sym {
271                    0 => {}
272                    1..=15 => {
273                        *len = sym;
274                        last_nonzero_sym = sym;
275                    }
276                    16 => {
277                        repeat_count = bitstream.peek_bits_prefilled(2) as usize + 3;
278                        bitstream.consume_bits(2)?;
279                        if prev_sym == 16 {
280                            repeat_count += last_repeat_count * 3 - 8;
281                            last_repeat_count += repeat_count;
282                        } else {
283                            last_repeat_count = repeat_count;
284                        }
285                        repeat_sym = last_nonzero_sym;
286
287                        *len = repeat_sym;
288                        repeat_count -= 1;
289                    }
290                    17 => {
291                        repeat_count = bitstream.peek_bits_prefilled(3) as usize + 3;
292                        bitstream.consume_bits(3)?;
293                        if prev_sym == 17 {
294                            repeat_count += last_repeat_count * 7 - 16;
295                            last_repeat_count += repeat_count;
296                        } else {
297                            last_repeat_count = repeat_count;
298                        }
299                        repeat_sym = 0;
300
301                        *len = repeat_sym;
302                        repeat_count -= 1;
303                    }
304                    _ => unreachable!(),
305                }
306                prev_sym = sym;
307            }
308
309            if *len != 0 {
310                bitacc += 1 << MAX_PREFIX_BITS.saturating_sub(*len as usize);
311
312                if bitacc > 1 << MAX_PREFIX_BITS {
313                    return Err(Error::PrefixSymbolTooLarge(bitacc));
314                } else if bitacc == 1 << MAX_PREFIX_BITS && repeat_count == 0 {
315                    break;
316                }
317            }
318        }
319
320        if bitacc != 1 << MAX_PREFIX_BITS || repeat_count > 0 {
321            return Err(Error::InvalidPrefixHistogram);
322        }
323        Self::with_code_lengths(code_lengths)
324    }
325}
326
327impl Histogram {
328    #[inline(always)]
329    pub fn read_symbol(&self, bitstream: &mut Bitstream) -> Result<u32> {
330        let Self {
331            toplevel_bits,
332            toplevel_mask,
333            ref toplevel_entries,
334            ref second_level_entries,
335        } = *self;
336        let peeked = bitstream.peek_bits_const::<MAX_PREFIX_BITS>();
337        let toplevel_offset = peeked & toplevel_mask;
338        let toplevel_entry = toplevel_entries[toplevel_offset as usize];
339        if toplevel_entry.nested {
340            let chunk_offset = (peeked >> toplevel_bits) & (toplevel_entry.bits_or_mask as u32);
341            let second_level_offset = toplevel_entry.symbol_or_offset as u32 + chunk_offset;
342            let second_level_entry = second_level_entries[second_level_offset as usize];
343            bitstream.consume_bits(second_level_entry.bits_or_mask as usize)?;
344            Ok(second_level_entry.symbol_or_offset as u32)
345        } else {
346            bitstream.consume_bits(toplevel_entry.bits_or_mask as usize)?;
347            Ok(toplevel_entry.symbol_or_offset as u32)
348        }
349    }
350
351    #[inline]
352    pub fn single_symbol(&self) -> Option<u32> {
353        if let &[Entry {
354            nested: false,
355            bits_or_mask: 0,
356            symbol_or_offset: symbol,
357        }] = &*self.toplevel_entries
358        {
359            Some(symbol as u32)
360        } else {
361            None
362        }
363    }
364}
365
366fn vec_reverse_bits(v: &[Entry], out: &mut Vec<Entry>) {
367    let len = v.len();
368    debug_assert!(len.is_power_of_two());
369    let bits = len.trailing_zeros();
370    let shift = usize::BITS - bits;
371    for idx in 0..len {
372        let rev_idx = idx.reverse_bits() >> shift;
373        let entry = v[rev_idx];
374        out.push(entry);
375    }
376}
377
#[cfg(test)]
mod tests {
    use super::Entry;
    use std::mem::size_of;

    /// Pins the in-memory size of a lookup table entry to four bytes.
    #[test]
    fn type_size() {
        assert_eq!(size_of::<Entry>(), 4);
    }
}