1use jxl_bitstream::Bitstream;
3
4use crate::{CodingResult, Error};
5
6const MAX_PREFIX_BITS: usize = 15;
7const MAX_TOPLEVEL_BITS: usize = 10;
8
/// Lookup tables for decoding symbols of a prefix code.
///
/// Decoding is done in up to two table lookups: the low `toplevel_bits` bits of the peeked
/// bitstream value index `toplevel_entries`, and entries flagged as nested redirect into a chunk
/// of `second_level_entries` that is indexed by the following bits.
#[derive(Debug)]
pub struct Histogram {
    /// Number of peeked bits used to index `toplevel_entries`.
    toplevel_bits: usize,
    /// Mask selecting the low `toplevel_bits` bits, i.e. `(1 << toplevel_bits) - 1`.
    toplevel_mask: u32,
    /// First-level table; holds `1 << toplevel_bits` entries (a single entry for a one-symbol
    /// histogram).
    toplevel_entries: Vec<Entry>,
    /// Concatenated second-level chunks, addressed through nested top-level entries.
    second_level_entries: Vec<Entry>,
}
16
/// A single slot of either lookup table.
///
/// The meaning of the two payload fields depends on `nested`.
#[derive(Clone, Copy, Debug, Default)]
struct Entry {
    /// `true` when this top-level slot redirects into the second-level table.
    nested: bool,
    /// Leaf slot: number of bits the matched code consumes.
    /// Nested slot: mask applied to the bits following the top-level index.
    bits_or_mask: u8,
    /// Leaf slot: the decoded symbol.
    /// Nested slot: offset of the chunk inside the second-level table.
    symbol_or_offset: u16,
}
23
24const _: () = {
25 ["size of `struct Entry`"][std::mem::size_of::<Entry>() - 4];
26};
27
28impl Histogram {
29 fn with_code_lengths(code_lengths: Vec<u8>) -> CodingResult<Self> {
30 let mut syms_for_length = Vec::with_capacity(MAX_PREFIX_BITS);
31 for (sym, len) in code_lengths.into_iter().enumerate() {
32 let sym = sym as u16;
33 if len > 0 {
34 if syms_for_length.len() < len as usize {
35 syms_for_length.resize_with(len as usize, Vec::new);
36 }
37 syms_for_length[len as usize - 1].push(sym);
38 }
39 }
40
41 let toplevel_bits = syms_for_length.len().min(MAX_TOPLEVEL_BITS);
42 let mut entries = vec![Entry::default(); 1 << toplevel_bits];
43 let mut current_bits = 0u16;
44 for (idx, syms) in syms_for_length.iter().enumerate().take(toplevel_bits) {
45 let shifts = toplevel_bits - 1 - idx;
46 for &sym in syms {
47 let entry = Entry {
48 nested: false,
49 bits_or_mask: (idx + 1) as u8,
50 symbol_or_offset: sym,
51 };
52 entries[current_bits as usize..][..(1 << shifts)].fill(entry);
53 current_bits += 1u16 << shifts;
54 }
55 }
56
57 let mut second_level_entries = Vec::new();
58 if toplevel_bits < syms_for_length.len() {
59 let mut remaining_entries = Vec::new();
60 let mut remaining_entry_bits = 0usize;
61 for (idx, syms) in syms_for_length.iter().enumerate().skip(toplevel_bits) {
62 if syms.is_empty() {
63 continue;
64 }
65
66 let chunk_size_bits = idx + 1 - toplevel_bits;
67 let chunk_size = 1usize << chunk_size_bits;
68 let mut chunk = Vec::with_capacity(chunk_size);
69 if !remaining_entries.is_empty() {
70 let mult = 1usize << (chunk_size_bits - remaining_entry_bits);
71 for entry in remaining_entries {
72 for _ in 0..mult {
73 chunk.push(entry);
74 }
75 }
76 }
77 for &sym in syms {
78 let entry = Entry {
79 nested: false,
80 bits_or_mask: (idx + 1) as u8,
81 symbol_or_offset: sym,
82 };
83 chunk.push(entry);
84 if chunk.len() == chunk_size {
85 entries[current_bits as usize] = Entry {
86 nested: true,
87 bits_or_mask: (chunk_size - 1) as u8,
88 symbol_or_offset: second_level_entries.len() as u16,
89 };
90 vec_reverse_bits(&chunk, &mut second_level_entries);
91 current_bits += 1;
92 chunk = Vec::with_capacity(chunk_size);
93 }
94 }
95 remaining_entries = chunk;
96 remaining_entry_bits = chunk_size_bits;
97 }
98
99 if !remaining_entries.is_empty() {
100 return Err(Error::InvalidPrefixHistogram);
101 }
102 }
103
104 if current_bits == 1 << toplevel_bits {
105 let mut toplevel_entries = Vec::with_capacity(entries.len());
106 vec_reverse_bits(&entries, &mut toplevel_entries);
107 Ok(Self {
108 toplevel_bits,
109 toplevel_mask: (1 << toplevel_bits) - 1,
110 toplevel_entries,
111 second_level_entries,
112 })
113 } else {
114 Err(Error::InvalidPrefixHistogram)
115 }
116 }
117
118 fn with_single_symbol(symbol: u16) -> Self {
119 let entry = Entry {
120 nested: false,
121 bits_or_mask: 0,
122 symbol_or_offset: symbol,
123 };
124 Self {
125 toplevel_bits: 0,
126 toplevel_mask: 0,
127 toplevel_entries: vec![entry],
128 second_level_entries: Vec::new(),
129 }
130 }
131
132 pub fn parse(bitstream: &mut Bitstream, alphabet_size: u32) -> CodingResult<Self> {
133 if alphabet_size == 1 {
134 return Ok(Self::with_single_symbol(0));
135 }
136
137 if alphabet_size > 1u32 << MAX_PREFIX_BITS {
138 return Err(Error::PrefixSymbolTooLarge(alphabet_size as usize));
139 }
140
141 let hskip = bitstream.read_bits(2)?;
142 if hskip == 1 {
143 Self::parse_simple(bitstream, alphabet_size)
144 } else {
145 Self::parse_complex(bitstream, alphabet_size, hskip)
146 }
147 }
148
149 fn parse_simple(bitstream: &mut Bitstream, alphabet_size: u32) -> CodingResult<Self> {
150 let alphabet_bits = alphabet_size.next_power_of_two().trailing_zeros() as usize;
151 let nsym = bitstream.read_bits(2)? + 1;
152 let it = match nsym {
153 1 => {
154 let sym = bitstream.read_bits(alphabet_bits)?;
155 if sym >= alphabet_size {
156 return Err(Error::InvalidPrefixHistogram);
157 }
158 return Ok(Self::with_single_symbol(sym as u16));
159 }
160 2 => {
161 let syms = [
162 0,
163 0,
164 bitstream.read_bits(alphabet_bits)? as usize,
165 bitstream.read_bits(alphabet_bits)? as usize,
166 ];
167
168 syms.into_iter().zip([0u8, 0, 1u8, 1])
169 }
170 3 => {
171 let syms = [
172 0,
173 bitstream.read_bits(alphabet_bits)? as usize,
174 bitstream.read_bits(alphabet_bits)? as usize,
175 bitstream.read_bits(alphabet_bits)? as usize,
176 ];
177
178 syms.into_iter().zip([0u8, 1, 2, 2])
179 }
180 4 => {
181 let syms = [
182 bitstream.read_bits(alphabet_bits)? as usize,
183 bitstream.read_bits(alphabet_bits)? as usize,
184 bitstream.read_bits(alphabet_bits)? as usize,
185 bitstream.read_bits(alphabet_bits)? as usize,
186 ];
187 let tree_selector = bitstream.read_bool()?;
188
189 if tree_selector {
190 syms.into_iter().zip([1u8, 2, 3, 3])
191 } else {
192 syms.into_iter().zip([2u8, 2, 2, 2])
193 }
194 }
195 _ => unreachable!(),
196 };
197
198 let mut code_lengths = vec![0u8; alphabet_size as usize];
199 for (sym, len) in it {
200 if let Some(out) = code_lengths.get_mut(sym) {
201 *out = len;
202 } else {
203 return Err(Error::InvalidPrefixHistogram);
204 }
205 }
206 Self::with_code_lengths(code_lengths)
207 }
208
209 fn parse_complex(
210 bitstream: &mut Bitstream,
211 alphabet_size: u32,
212 hskip: u32,
213 ) -> CodingResult<Self> {
214 const CODE_LENGTH_ORDER: [usize; 18] =
215 [1, 2, 3, 4, 0, 5, 17, 6, 16, 7, 8, 9, 10, 11, 12, 13, 14, 15];
216 let mut code_length_code_lengths = [0u8; 18];
217 let mut bitacc = 0usize;
218
219 let mut nonzero_count = 0;
220 let mut nonzero_sym = 0;
221 for idx in CODE_LENGTH_ORDER.into_iter().skip(hskip as usize) {
222 let base = bitstream.read_u32(0, 4, 3, 8)? as u8;
224 let len = if base == 8 {
225 if bitstream.read_bool()? {
226 if bitstream.read_bool()? {
227 5
229 } else {
230 1
232 }
233 } else {
234 2
236 }
237 } else {
238 base
239 };
240
241 code_length_code_lengths[idx] = len;
242 if len != 0 {
243 nonzero_count += 1;
244 nonzero_sym = idx;
245 bitacc += 32 >> len;
246
247 match bitacc.cmp(&32) {
248 std::cmp::Ordering::Less => {}
249 std::cmp::Ordering::Equal => break,
250 std::cmp::Ordering::Greater => return Err(Error::InvalidPrefixHistogram),
251 }
252 }
253 }
254
255 let code_length_histogram = if nonzero_count == 1 {
256 Histogram::with_single_symbol(nonzero_sym as u16)
257 } else if bitacc != 32 {
258 return Err(Error::InvalidPrefixHistogram);
259 } else {
260 Histogram::with_code_lengths(code_length_code_lengths.to_vec())?
261 };
262
263 let mut code_lengths = vec![0u8; alphabet_size as usize];
264 let mut bitacc = 0usize;
265
266 let mut prev_sym = 8u8;
267 let mut last_nonzero_sym = 8u8;
268 let mut last_repeat_count = 0usize;
269
270 let mut repeat_count = 0usize;
271 let mut repeat_sym = 0u8;
272 for len in &mut code_lengths {
273 if repeat_count > 0 {
274 *len = repeat_sym;
275 repeat_count -= 1;
276 } else {
277 let sym = code_length_histogram.read_symbol(bitstream)? as u8;
278 match sym {
279 0 => {}
280 1..=15 => {
281 *len = sym;
282 last_nonzero_sym = sym;
283 }
284 16 => {
285 repeat_count = bitstream.peek_bits_prefilled(2) as usize + 3;
286 bitstream.consume_bits(2)?;
287 if prev_sym == 16 {
288 repeat_count += last_repeat_count * 3 - 8;
289 last_repeat_count += repeat_count;
290 } else {
291 last_repeat_count = repeat_count;
292 }
293 repeat_sym = last_nonzero_sym;
294
295 *len = repeat_sym;
296 repeat_count -= 1;
297 }
298 17 => {
299 repeat_count = bitstream.peek_bits_prefilled(3) as usize + 3;
300 bitstream.consume_bits(3)?;
301 if prev_sym == 17 {
302 repeat_count += last_repeat_count * 7 - 16;
303 last_repeat_count += repeat_count;
304 } else {
305 last_repeat_count = repeat_count;
306 }
307 repeat_sym = 0;
308
309 *len = repeat_sym;
310 repeat_count -= 1;
311 }
312 _ => unreachable!(),
313 }
314 prev_sym = sym;
315 }
316
317 if *len != 0 {
318 bitacc += 1 << MAX_PREFIX_BITS.saturating_sub(*len as usize);
319
320 if bitacc > 1 << MAX_PREFIX_BITS {
321 return Err(Error::PrefixSymbolTooLarge(bitacc));
322 } else if bitacc == 1 << MAX_PREFIX_BITS && repeat_count == 0 {
323 break;
324 }
325 }
326 }
327
328 if bitacc != 1 << MAX_PREFIX_BITS || repeat_count > 0 {
329 return Err(Error::InvalidPrefixHistogram);
330 }
331 Self::with_code_lengths(code_lengths)
332 }
333}
334
335impl Histogram {
336 #[inline(always)]
337 pub fn read_symbol(&self, bitstream: &mut Bitstream) -> CodingResult<u32> {
338 let Self {
339 toplevel_bits,
340 toplevel_mask,
341 ref toplevel_entries,
342 ref second_level_entries,
343 } = *self;
344 let peeked = bitstream.peek_bits_const::<MAX_PREFIX_BITS>();
345 let toplevel_offset = peeked & toplevel_mask;
346 let toplevel_entry = toplevel_entries[toplevel_offset as usize];
347 if toplevel_entry.nested {
348 let chunk_offset = (peeked >> toplevel_bits) & (toplevel_entry.bits_or_mask as u32);
349 let second_level_offset = toplevel_entry.symbol_or_offset as u32 + chunk_offset;
350 let second_level_entry = second_level_entries[second_level_offset as usize];
351 bitstream.consume_bits(second_level_entry.bits_or_mask as usize)?;
352 Ok(second_level_entry.symbol_or_offset as u32)
353 } else {
354 bitstream.consume_bits(toplevel_entry.bits_or_mask as usize)?;
355 Ok(toplevel_entry.symbol_or_offset as u32)
356 }
357 }
358
359 #[inline]
360 pub fn single_symbol(&self) -> Option<u32> {
361 if let &[
362 Entry {
363 nested: false,
364 bits_or_mask: 0,
365 symbol_or_offset: symbol,
366 },
367 ] = &*self.toplevel_entries
368 {
369 Some(symbol as u32)
370 } else {
371 None
372 }
373 }
374}
375
376fn vec_reverse_bits(v: &[Entry], out: &mut Vec<Entry>) {
377 let len = v.len();
378 debug_assert!(len.is_power_of_two());
379 let bits = len.trailing_zeros();
380 let shift = usize::BITS - bits;
381 for idx in 0..len {
382 let rev_idx = idx.reverse_bits() >> shift;
383 let entry = v[rev_idx];
384 out.push(entry);
385 }
386}