Skip to main content

jxl_coding/
prefix.rs

//! Prefix code based on Brotli.
2use jxl_bitstream::Bitstream;
3
4use crate::{CodingResult, Error};
5
/// Longest code length, in bits, that a symbol of a prefix code may have.
///
/// It also bounds the alphabet size (`1 << MAX_PREFIX_BITS` symbols) and is the number of bits
/// peeked from the bitstream when decoding one symbol.
const MAX_PREFIX_BITS: usize = 15;
/// Upper bound on the number of bits indexed by the first-level lookup table.
///
/// Codes longer than this are resolved through second-level chunks.
const MAX_TOPLEVEL_BITS: usize = 10;
8
9#[derive(Debug)]
10pub struct Histogram {
11    toplevel_bits: usize,
12    toplevel_mask: u32,
13    toplevel_entries: Vec<Entry>,
14    second_level_entries: Vec<Entry>,
15}
16
/// One slot of a prefix code lookup table.
///
/// The two payload fields are interpreted according to `nested`:
///
/// * leaf entry (`nested == false`): `bits_or_mask` is the code length, i.e. the number of
///   bits to consume from the bitstream, and `symbol_or_offset` is the decoded symbol;
/// * nested entry (`nested == true`, first-level table only): `bits_or_mask` is the mask applied
///   to the bits following the first-level index, and `symbol_or_offset` is the position of the
///   chunk's first entry within the second-level table.
#[derive(Debug, Copy, Clone, Default)]
struct Entry {
    nested: bool,
    bits_or_mask: u8,
    symbol_or_offset: u16,
}

// Compile-time guard on the size of `struct Entry`: the array lengths only agree, and the file
// only compiles, while an entry occupies exactly four bytes.
const _: [(); 4] = [(); std::mem::size_of::<Entry>()];
27
28impl Histogram {
29    fn with_code_lengths(code_lengths: Vec<u8>) -> CodingResult<Self> {
30        let mut syms_for_length = Vec::with_capacity(MAX_PREFIX_BITS);
31        for (sym, len) in code_lengths.into_iter().enumerate() {
32            let sym = sym as u16;
33            if len > 0 {
34                if syms_for_length.len() < len as usize {
35                    syms_for_length.resize_with(len as usize, Vec::new);
36                }
37                syms_for_length[len as usize - 1].push(sym);
38            }
39        }
40
41        let toplevel_bits = syms_for_length.len().min(MAX_TOPLEVEL_BITS);
42        let mut entries = vec![Entry::default(); 1 << toplevel_bits];
43        let mut current_bits = 0u16;
44        for (idx, syms) in syms_for_length.iter().enumerate().take(toplevel_bits) {
45            let shifts = toplevel_bits - 1 - idx;
46            for &sym in syms {
47                let entry = Entry {
48                    nested: false,
49                    bits_or_mask: (idx + 1) as u8,
50                    symbol_or_offset: sym,
51                };
52                entries[current_bits as usize..][..(1 << shifts)].fill(entry);
53                current_bits += 1u16 << shifts;
54            }
55        }
56
57        let mut second_level_entries = Vec::new();
58        if toplevel_bits < syms_for_length.len() {
59            let mut remaining_entries = Vec::new();
60            let mut remaining_entry_bits = 0usize;
61            for (idx, syms) in syms_for_length.iter().enumerate().skip(toplevel_bits) {
62                if syms.is_empty() {
63                    continue;
64                }
65
66                let chunk_size_bits = idx + 1 - toplevel_bits;
67                let chunk_size = 1usize << chunk_size_bits;
68                let mut chunk = Vec::with_capacity(chunk_size);
69                if !remaining_entries.is_empty() {
70                    let mult = 1usize << (chunk_size_bits - remaining_entry_bits);
71                    for entry in remaining_entries {
72                        for _ in 0..mult {
73                            chunk.push(entry);
74                        }
75                    }
76                }
77                for &sym in syms {
78                    let entry = Entry {
79                        nested: false,
80                        bits_or_mask: (idx + 1) as u8,
81                        symbol_or_offset: sym,
82                    };
83                    chunk.push(entry);
84                    if chunk.len() == chunk_size {
85                        entries[current_bits as usize] = Entry {
86                            nested: true,
87                            bits_or_mask: (chunk_size - 1) as u8,
88                            symbol_or_offset: second_level_entries.len() as u16,
89                        };
90                        vec_reverse_bits(&chunk, &mut second_level_entries);
91                        current_bits += 1;
92                        chunk = Vec::with_capacity(chunk_size);
93                    }
94                }
95                remaining_entries = chunk;
96                remaining_entry_bits = chunk_size_bits;
97            }
98
99            if !remaining_entries.is_empty() {
100                return Err(Error::InvalidPrefixHistogram);
101            }
102        }
103
104        if current_bits == 1 << toplevel_bits {
105            let mut toplevel_entries = Vec::with_capacity(entries.len());
106            vec_reverse_bits(&entries, &mut toplevel_entries);
107            Ok(Self {
108                toplevel_bits,
109                toplevel_mask: (1 << toplevel_bits) - 1,
110                toplevel_entries,
111                second_level_entries,
112            })
113        } else {
114            Err(Error::InvalidPrefixHistogram)
115        }
116    }
117
118    fn with_single_symbol(symbol: u16) -> Self {
119        let entry = Entry {
120            nested: false,
121            bits_or_mask: 0,
122            symbol_or_offset: symbol,
123        };
124        Self {
125            toplevel_bits: 0,
126            toplevel_mask: 0,
127            toplevel_entries: vec![entry],
128            second_level_entries: Vec::new(),
129        }
130    }
131
132    pub fn parse(bitstream: &mut Bitstream, alphabet_size: u32) -> CodingResult<Self> {
133        if alphabet_size == 1 {
134            return Ok(Self::with_single_symbol(0));
135        }
136
137        if alphabet_size > 1u32 << MAX_PREFIX_BITS {
138            return Err(Error::PrefixSymbolTooLarge(alphabet_size as usize));
139        }
140
141        let hskip = bitstream.read_bits(2)?;
142        if hskip == 1 {
143            Self::parse_simple(bitstream, alphabet_size)
144        } else {
145            Self::parse_complex(bitstream, alphabet_size, hskip)
146        }
147    }
148
149    fn parse_simple(bitstream: &mut Bitstream, alphabet_size: u32) -> CodingResult<Self> {
150        let alphabet_bits = alphabet_size.next_power_of_two().trailing_zeros() as usize;
151        let nsym = bitstream.read_bits(2)? + 1;
152        let it = match nsym {
153            1 => {
154                let sym = bitstream.read_bits(alphabet_bits)?;
155                if sym >= alphabet_size {
156                    return Err(Error::InvalidPrefixHistogram);
157                }
158                return Ok(Self::with_single_symbol(sym as u16));
159            }
160            2 => {
161                let syms = [
162                    0,
163                    0,
164                    bitstream.read_bits(alphabet_bits)? as usize,
165                    bitstream.read_bits(alphabet_bits)? as usize,
166                ];
167
168                syms.into_iter().zip([0u8, 0, 1u8, 1])
169            }
170            3 => {
171                let syms = [
172                    0,
173                    bitstream.read_bits(alphabet_bits)? as usize,
174                    bitstream.read_bits(alphabet_bits)? as usize,
175                    bitstream.read_bits(alphabet_bits)? as usize,
176                ];
177
178                syms.into_iter().zip([0u8, 1, 2, 2])
179            }
180            4 => {
181                let syms = [
182                    bitstream.read_bits(alphabet_bits)? as usize,
183                    bitstream.read_bits(alphabet_bits)? as usize,
184                    bitstream.read_bits(alphabet_bits)? as usize,
185                    bitstream.read_bits(alphabet_bits)? as usize,
186                ];
187                let tree_selector = bitstream.read_bool()?;
188
189                if tree_selector {
190                    syms.into_iter().zip([1u8, 2, 3, 3])
191                } else {
192                    syms.into_iter().zip([2u8, 2, 2, 2])
193                }
194            }
195            _ => unreachable!(),
196        };
197
198        let mut code_lengths = vec![0u8; alphabet_size as usize];
199        for (sym, len) in it {
200            if let Some(out) = code_lengths.get_mut(sym) {
201                *out = len;
202            } else {
203                return Err(Error::InvalidPrefixHistogram);
204            }
205        }
206        Self::with_code_lengths(code_lengths)
207    }
208
209    fn parse_complex(
210        bitstream: &mut Bitstream,
211        alphabet_size: u32,
212        hskip: u32,
213    ) -> CodingResult<Self> {
214        const CODE_LENGTH_ORDER: [usize; 18] =
215            [1, 2, 3, 4, 0, 5, 17, 6, 16, 7, 8, 9, 10, 11, 12, 13, 14, 15];
216        let mut code_length_code_lengths = [0u8; 18];
217        let mut bitacc = 0usize;
218
219        let mut nonzero_count = 0;
220        let mut nonzero_sym = 0;
221        for idx in CODE_LENGTH_ORDER.into_iter().skip(hskip as usize) {
222            // Read single code length code.
223            let base = bitstream.read_u32(0, 4, 3, 8)? as u8;
224            let len = if base == 8 {
225                if bitstream.read_bool()? {
226                    if bitstream.read_bool()? {
227                        // 1111
228                        5
229                    } else {
230                        // 0111
231                        1
232                    }
233                } else {
234                    // 011
235                    2
236                }
237            } else {
238                base
239            };
240
241            code_length_code_lengths[idx] = len;
242            if len != 0 {
243                nonzero_count += 1;
244                nonzero_sym = idx;
245                bitacc += 32 >> len;
246
247                match bitacc.cmp(&32) {
248                    std::cmp::Ordering::Less => {}
249                    std::cmp::Ordering::Equal => break,
250                    std::cmp::Ordering::Greater => return Err(Error::InvalidPrefixHistogram),
251                }
252            }
253        }
254
255        let code_length_histogram = if nonzero_count == 1 {
256            Histogram::with_single_symbol(nonzero_sym as u16)
257        } else if bitacc != 32 {
258            return Err(Error::InvalidPrefixHistogram);
259        } else {
260            Histogram::with_code_lengths(code_length_code_lengths.to_vec())?
261        };
262
263        let mut code_lengths = vec![0u8; alphabet_size as usize];
264        let mut bitacc = 0usize;
265
266        let mut prev_sym = 8u8;
267        let mut last_nonzero_sym = 8u8;
268        let mut last_repeat_count = 0usize;
269
270        let mut repeat_count = 0usize;
271        let mut repeat_sym = 0u8;
272        for len in &mut code_lengths {
273            if repeat_count > 0 {
274                *len = repeat_sym;
275                repeat_count -= 1;
276            } else {
277                let sym = code_length_histogram.read_symbol(bitstream)? as u8;
278                match sym {
279                    0 => {}
280                    1..=15 => {
281                        *len = sym;
282                        last_nonzero_sym = sym;
283                    }
284                    16 => {
285                        repeat_count = bitstream.peek_bits_prefilled(2) as usize + 3;
286                        bitstream.consume_bits(2)?;
287                        if prev_sym == 16 {
288                            repeat_count += last_repeat_count * 3 - 8;
289                            last_repeat_count += repeat_count;
290                        } else {
291                            last_repeat_count = repeat_count;
292                        }
293                        repeat_sym = last_nonzero_sym;
294
295                        *len = repeat_sym;
296                        repeat_count -= 1;
297                    }
298                    17 => {
299                        repeat_count = bitstream.peek_bits_prefilled(3) as usize + 3;
300                        bitstream.consume_bits(3)?;
301                        if prev_sym == 17 {
302                            repeat_count += last_repeat_count * 7 - 16;
303                            last_repeat_count += repeat_count;
304                        } else {
305                            last_repeat_count = repeat_count;
306                        }
307                        repeat_sym = 0;
308
309                        *len = repeat_sym;
310                        repeat_count -= 1;
311                    }
312                    _ => unreachable!(),
313                }
314                prev_sym = sym;
315            }
316
317            if *len != 0 {
318                bitacc += 1 << MAX_PREFIX_BITS.saturating_sub(*len as usize);
319
320                if bitacc > 1 << MAX_PREFIX_BITS {
321                    return Err(Error::PrefixSymbolTooLarge(bitacc));
322                } else if bitacc == 1 << MAX_PREFIX_BITS && repeat_count == 0 {
323                    break;
324                }
325            }
326        }
327
328        if bitacc != 1 << MAX_PREFIX_BITS || repeat_count > 0 {
329            return Err(Error::InvalidPrefixHistogram);
330        }
331        Self::with_code_lengths(code_lengths)
332    }
333}
334
335impl Histogram {
336    #[inline(always)]
337    pub fn read_symbol(&self, bitstream: &mut Bitstream) -> CodingResult<u32> {
338        let Self {
339            toplevel_bits,
340            toplevel_mask,
341            ref toplevel_entries,
342            ref second_level_entries,
343        } = *self;
344        let peeked = bitstream.peek_bits_const::<MAX_PREFIX_BITS>();
345        let toplevel_offset = peeked & toplevel_mask;
346        let toplevel_entry = toplevel_entries[toplevel_offset as usize];
347        if toplevel_entry.nested {
348            let chunk_offset = (peeked >> toplevel_bits) & (toplevel_entry.bits_or_mask as u32);
349            let second_level_offset = toplevel_entry.symbol_or_offset as u32 + chunk_offset;
350            let second_level_entry = second_level_entries[second_level_offset as usize];
351            bitstream.consume_bits(second_level_entry.bits_or_mask as usize)?;
352            Ok(second_level_entry.symbol_or_offset as u32)
353        } else {
354            bitstream.consume_bits(toplevel_entry.bits_or_mask as usize)?;
355            Ok(toplevel_entry.symbol_or_offset as u32)
356        }
357    }
358
359    #[inline]
360    pub fn single_symbol(&self) -> Option<u32> {
361        if let &[
362            Entry {
363                nested: false,
364                bits_or_mask: 0,
365                symbol_or_offset: symbol,
366            },
367        ] = &*self.toplevel_entries
368        {
369            Some(symbol as u32)
370        } else {
371            None
372        }
373    }
374}
375
376fn vec_reverse_bits(v: &[Entry], out: &mut Vec<Entry>) {
377    let len = v.len();
378    debug_assert!(len.is_power_of_two());
379    let bits = len.trailing_zeros();
380    let shift = usize::BITS - bits;
381    for idx in 0..len {
382        let rev_idx = idx.reverse_bits() >> shift;
383        let entry = v[rev_idx];
384        out.push(entry);
385    }
386}