1use std::sync::Arc;
8
9use jxl_bitstream::{Bitstream, U};
10
11mod ans;
12mod error;
13mod permutation;
14mod prefix;
15
16pub use error::Error;
17
/// Result type returned by the entropy-decoding routines in this module; the error type is
/// this crate's [`Error`].
pub type CodingResult<T> = std::result::Result<T, Error>;
20
21pub use permutation::read_permutation;
22
/// Entropy decoder for a set of distributions, optionally with LZ77 back-references.
///
/// Built with [`Decoder::parse`]. `read_clusters` in this module shows the usage pattern:
/// [`Decoder::begin`], then any number of `read_varint*` calls, then [`Decoder::finalize`].
#[derive(Debug, Clone)]
pub struct Decoder {
    // LZ77 parameters plus sliding-window state, or `Lz77::Disabled`.
    lz77: Lz77,
    // Context-to-cluster map, per-cluster integer configs and the prefix/ANS coder.
    inner: DecoderInner,
}
29
30impl Decoder {
31 pub fn parse(bitstream: &mut Bitstream, num_dist: u32) -> CodingResult<Self> {
34 let lz77 = Lz77::parse(bitstream)?;
35 let num_dist = if let Lz77::Disabled = &lz77 {
36 num_dist
37 } else {
38 num_dist + 1
39 };
40 let inner = DecoderInner::parse(bitstream, num_dist)?;
41 Ok(Self { lz77, inner })
42 }
43
44 fn parse_assume_no_lz77(bitstream: &mut Bitstream, num_dist: u32) -> CodingResult<Self> {
45 let lz77_enabled = bitstream.read_bool()?;
46 if lz77_enabled {
47 return Err(Error::Lz77NotAllowed);
48 }
49 let inner = DecoderInner::parse(bitstream, num_dist)?;
50 Ok(Self {
51 lz77: Lz77::Disabled,
52 inner,
53 })
54 }
55
56 #[inline]
58 pub fn read_varint(&mut self, bitstream: &mut Bitstream, ctx: u32) -> CodingResult<u32> {
59 self.read_varint_with_multiplier(bitstream, ctx, 0)
60 }
61
62 #[inline]
64 pub fn read_varint_with_multiplier(
65 &mut self,
66 bitstream: &mut Bitstream,
67 ctx: u32,
68 dist_multiplier: u32,
69 ) -> CodingResult<u32> {
70 let cluster = self.inner.clusters[ctx as usize];
71 self.read_varint_with_multiplier_clustered(bitstream, cluster, dist_multiplier)
72 }
73
74 #[inline(always)]
78 pub fn read_varint_with_multiplier_clustered(
79 &mut self,
80 bitstream: &mut Bitstream,
81 cluster: u8,
82 dist_multiplier: u32,
83 ) -> CodingResult<u32> {
84 if let Lz77::Enabled {
85 ref mut state,
86 min_symbol,
87 min_length,
88 } = self.lz77
89 {
90 self.inner.read_varint_with_multiplier_clustered_lz77(
91 bitstream,
92 cluster,
93 dist_multiplier,
94 state,
95 min_symbol,
96 min_length,
97 )
98 } else {
99 self.inner
100 .read_varint_with_multiplier_clustered(bitstream, cluster)
101 }
102 }
103
104 pub fn as_rle(&mut self) -> Option<DecoderRleMode<'_>> {
106 let &Lz77::Enabled {
107 ref state,
108 min_symbol,
109 min_length,
110 } = &self.lz77
111 else {
112 return None;
113 };
114 let lz_cluster = self.inner.lz_dist_cluster();
115 let lz_conf = &self.inner.configs[lz_cluster as usize];
116 let sym = self.inner.code.single_symbol(lz_cluster)?;
117 (sym == 1 && lz_conf.split_exponent == 0).then_some(DecoderRleMode {
118 inner: &mut self.inner,
119 min_symbol,
120 min_length,
121 len_config: state.lz_len_conf.clone(),
122 })
123 }
124
125 pub fn as_with_lz77(&mut self) -> Option<DecoderWithLz77<'_>> {
127 if let Lz77::Enabled {
128 ref mut state,
129 min_symbol,
130 min_length,
131 } = self.lz77
132 {
133 Some(DecoderWithLz77 {
134 inner: &mut self.inner,
135 state,
136 min_symbol,
137 min_length,
138 })
139 } else {
140 None
141 }
142 }
143
144 pub fn as_no_lz77(&mut self) -> Option<DecoderNoLz77<'_>> {
146 if let Lz77::Disabled = self.lz77 {
147 Some(DecoderNoLz77(&mut self.inner))
148 } else {
149 None
150 }
151 }
152
153 #[inline]
155 pub fn single_token(&self, cluster: u8) -> Option<u32> {
156 if let Lz77::Enabled { .. } = self.lz77 {
157 return None;
158 }
159 self.inner.single_token(cluster)
160 }
161
162 #[inline]
167 pub fn begin(&mut self, bitstream: &mut Bitstream) -> CodingResult<()> {
168 self.inner.code.begin(bitstream)
169 }
170
171 #[inline]
176 pub fn finalize(&self) -> CodingResult<()> {
177 self.inner.code.finalize()
178 }
179
180 #[inline]
182 pub fn cluster_map(&self) -> &[u8] {
183 &self.inner.clusters
184 }
185}
186
/// Borrowed run-length view of a [`Decoder`], created by [`Decoder::as_rle`].
///
/// It exists only when the LZ77 distance cluster can produce nothing but token `1` with
/// `split_exponent == 0`. LZ77 copies are then reported as [`RleToken::Repeat`] instead of
/// being expanded from the LZ77 window.
#[derive(Debug)]
pub struct DecoderRleMode<'dec> {
    // Shared cluster map / configs / coder of the parent decoder.
    inner: &'dec mut DecoderInner,
    // Tokens >= this value are LZ77 length tokens rather than literal values.
    min_symbol: u32,
    // Added to every decoded LZ77 length.
    min_length: u32,
    // Integer config used to decode LZ77 lengths.
    len_config: IntegerConfig,
}
195
/// One item decoded in RLE mode (see `DecoderRleMode::read_varint_clustered`).
#[derive(Debug, Copy, Clone)]
pub enum RleToken {
    /// A literal decoded value.
    Value(u32),
    /// The length of an LZ77 copy (already offset by the stream's minimum length).
    /// NOTE(review): expanding it is left to the caller, presumably by repeating the
    /// previously decoded value; confirm against callers.
    Repeat(u32),
}
204
205impl DecoderRleMode<'_> {
206 #[inline]
210 pub fn read_varint_clustered(
211 &mut self,
212 bitstream: &mut Bitstream,
213 cluster: u8,
214 ) -> CodingResult<RleToken> {
215 self.inner
216 .code
217 .read_symbol(bitstream, cluster)
218 .map(|token| {
219 if let Some(token) = token.checked_sub(self.min_symbol) {
220 RleToken::Repeat(
221 self.inner
222 .read_uint_prefilled(bitstream, &self.len_config, token)
223 + self.min_length,
224 )
225 } else {
226 RleToken::Value(self.inner.read_uint_prefilled(
227 bitstream,
228 &self.inner.configs[cluster as usize],
229 token,
230 ))
231 }
232 })
233 }
234
235 #[inline]
237 pub fn cluster_map(&self) -> &[u8] {
238 &self.inner.clusters
239 }
240}
241
/// Borrowed view of a [`Decoder`] with LZ77 enabled, created by [`Decoder::as_with_lz77`].
#[derive(Debug)]
pub struct DecoderWithLz77<'dec> {
    // Shared cluster map / configs / coder of the parent decoder.
    inner: &'dec mut DecoderInner,
    // The parent decoder's LZ77 window and copy progress.
    state: &'dec mut Lz77State,
    // Tokens >= this value start an LZ77 copy.
    min_symbol: u32,
    // Added to every decoded LZ77 copy length.
    min_length: u32,
}
250
251impl DecoderWithLz77<'_> {
252 #[inline]
256 pub fn read_varint_with_multiplier_clustered(
257 &mut self,
258 bitstream: &mut Bitstream,
259 cluster: u8,
260 dist_multiplier: u32,
261 ) -> CodingResult<u32> {
262 self.inner.read_varint_with_multiplier_clustered_lz77(
263 bitstream,
264 cluster,
265 dist_multiplier,
266 self.state,
267 self.min_symbol,
268 self.min_length,
269 )
270 }
271
272 #[inline]
274 pub fn cluster_map(&self) -> &[u8] {
275 &self.inner.clusters
276 }
277}
278
/// Borrowed view of a [`Decoder`] with LZ77 disabled, created by [`Decoder::as_no_lz77`].
#[derive(Debug)]
pub struct DecoderNoLz77<'dec>(&'dec mut DecoderInner);
282
283impl DecoderNoLz77<'_> {
284 #[inline]
288 pub fn read_varint_clustered(
289 &mut self,
290 bitstream: &mut Bitstream,
291 cluster: u8,
292 ) -> CodingResult<u32> {
293 self.0
294 .read_varint_with_multiplier_clustered(bitstream, cluster)
295 }
296
297 #[inline]
299 pub fn single_token(&self, cluster: u8) -> Option<u32> {
300 self.0.single_token(cluster)
301 }
302
303 #[inline]
305 pub fn cluster_map(&self) -> &[u8] {
306 &self.0.clusters
307 }
308}
309
/// LZ77 configuration of a decoder, as parsed by `Lz77::parse`.
#[derive(Debug, Clone)]
enum Lz77 {
    /// LZ77 is not used; every token is a literal.
    Disabled,
    /// LZ77 is used.
    Enabled {
        // Tokens >= this value start an LZ77 copy.
        min_symbol: u32,
        // Added to every decoded copy length.
        min_length: u32,
        // Length config plus the sliding window and copy progress.
        state: Lz77State,
    },
}
319
320impl Lz77 {
321 fn parse(bitstream: &mut Bitstream) -> CodingResult<Self> {
322 Ok(if bitstream.read_bool()? {
323 let min_symbol = bitstream.read_u32(224, 512, 4096, 8 + U(15))?;
325 let min_length = bitstream.read_u32(3, 4, 5 + U(2), 9 + U(8))?;
326 let lz_len_conf = IntegerConfig::parse(bitstream, 8)?;
327 Self::Enabled {
328 min_symbol,
329 min_length,
330 state: Lz77State::new(lz_len_conf),
331 }
332 } else {
333 Self::Disabled
334 })
335 }
336}
337
/// Mutable LZ77 decoding state: the length config, the window of past values and the
/// progress of the copy currently in flight.
#[derive(Clone)]
struct Lz77State {
    // Integer config used to decode LZ77 copy lengths.
    lz_len_conf: IntegerConfig,
    // Previously produced values. `DecoderInner` indexes it with `& 0xfffff`, so it acts as
    // a ring buffer of up to 2^20 entries, grown by `push` until full.
    window: Vec<u32>,
    // Values still to be copied by the current LZ77 match.
    num_to_copy: u32,
    // Absolute position (before masking) of the next value to copy.
    copy_pos: u32,
    // Total number of values produced so far.
    num_decoded: u32,
}
346
347impl std::fmt::Debug for Lz77State {
348 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
349 f.debug_struct("Lz77State")
350 .field("lz_len_conf", &self.lz_len_conf)
351 .field("num_to_copy", &self.num_to_copy)
352 .field("copy_pos", &self.copy_pos)
353 .field("num_decoded", &self.num_decoded)
354 .finish_non_exhaustive()
355 }
356}
357
358impl Lz77State {
359 fn new(lz_len_conf: IntegerConfig) -> Self {
360 Self {
361 lz_len_conf,
362 window: Vec::new(),
363 num_to_copy: 0,
364 copy_pos: 0,
365 num_decoded: 0,
366 }
367 }
368}
369
/// Hybrid-integer configuration: how a decoded token is expanded into an integer.
///
/// Tokens below `split` are the value itself. Larger tokens carry `msb_in_token` high bits
/// and `lsb_in_token` low bits; the remaining middle bits come from the bitstream (see
/// `DecoderInner::read_uint_prefilled`).
#[derive(Debug, Clone)]
struct IntegerConfig {
    // log2 of `split`.
    split_exponent: u32,
    // `1 << split_exponent`; tokens below this decode to themselves.
    split: u32,
    // Number of high-order bits (below the implicit leading 1) stored in the token.
    msb_in_token: u32,
    // Number of low-order bits stored in the token.
    lsb_in_token: u32,
}
377
378impl IntegerConfig {
379 fn parse(bitstream: &mut Bitstream, log_alphabet_size: u32) -> CodingResult<Self> {
380 let split_exponent_bits = add_log2_ceil(log_alphabet_size);
381 let split_exponent = bitstream.read_bits(split_exponent_bits as usize)?;
382 let (msb_in_token, lsb_in_token) = if split_exponent != log_alphabet_size {
383 let msb_bits = add_log2_ceil(split_exponent) as usize;
384 let msb_in_token = bitstream.read_bits(msb_bits)?;
385 if msb_in_token > split_exponent {
386 return Err(Error::InvalidIntegerConfig {
387 split_exponent,
388 msb_in_token,
389 lsb_in_token: None,
390 });
391 }
392 let lsb_bits = add_log2_ceil(split_exponent - msb_in_token) as usize;
393 let lsb_in_token = bitstream.read_bits(lsb_bits)?;
394 (msb_in_token, lsb_in_token)
395 } else {
396 (0u32, 0u32)
397 };
398 if lsb_in_token + msb_in_token > split_exponent {
399 return Err(Error::InvalidIntegerConfig {
400 split_exponent,
401 msb_in_token,
402 lsb_in_token: Some(lsb_in_token),
403 });
404 }
405 Ok(Self {
406 split_exponent,
407 split: 1 << split_exponent,
408 msb_in_token,
409 lsb_in_token,
410 })
411 }
412}
413
/// Everything a decoder needs apart from LZ77: the cluster map, the integer configs and the
/// symbol coder.
#[derive(Debug, Clone)]
struct DecoderInner {
    // Context -> cluster index (`clusters[ctx]`).
    clusters: Vec<u8>,
    // One hybrid-integer config per cluster.
    configs: Vec<IntegerConfig>,
    // Prefix-code or ANS symbol decoder holding one histogram per cluster.
    code: Coder,
}
420
impl DecoderInner {
    /// Parses the cluster map, integer configs and histograms for `num_dist` distributions.
    ///
    /// Read order:
    /// 1. the cluster map (`read_clusters`);
    /// 2. the `use_prefix_code` flag;
    /// 3. the log alphabet size (15 for prefix codes, otherwise 2 bits + 5);
    /// 4. one `IntegerConfig` per cluster;
    /// 5. either all prefix-code alphabet sizes followed by the prefix histograms, or one
    ///    ANS histogram per cluster.
    fn parse(bitstream: &mut Bitstream, num_dist: u32) -> CodingResult<Self> {
        let (num_clusters, clusters) = read_clusters(bitstream, num_dist)?;
        let use_prefix_code = bitstream.read_bool()?;
        let log_alphabet_size = if use_prefix_code {
            15
        } else {
            bitstream.read_bits(2)? + 5
        };
        let configs = (0..num_clusters)
            .map(|_| IntegerConfig::parse(bitstream, log_alphabet_size))
            .collect::<CodingResult<Vec<_>>>()?;
        let code = if use_prefix_code {
            // Every cluster's alphabet size is read before any prefix histogram.
            let counts = (0..num_clusters)
                .map(|_| -> CodingResult<_> {
                    let count = if bitstream.read_bool()? {
                        let n = bitstream.read_bits(4)? as usize;
                        1 + (1 << n) + bitstream.read_bits(n)?
                    } else {
                        1
                    };
                    // Prefix-code alphabets may not exceed 2^15 symbols.
                    if count > 1 << 15 {
                        return Err(Error::InvalidPrefixHistogram);
                    }
                    Ok(count)
                })
                .collect::<CodingResult<Vec<_>>>()?;
            let dist = counts
                .into_iter()
                .map(|count| prefix::Histogram::parse(bitstream, count))
                .collect::<CodingResult<Vec<_>>>()?;
            Coder::PrefixCode(Arc::new(dist))
        } else {
            let dist = (0..num_clusters)
                .map(|_| ans::Histogram::parse(bitstream, log_alphabet_size))
                .collect::<CodingResult<Vec<_>>>()?;
            Coder::Ans {
                dist: Arc::new(dist),
                state: 0,
                // The 32-bit initial state has not been read yet. `Coder::begin` or the
                // first `Coder::read_symbol` reads it.
                initial: true,
            }
        };
        Ok(Self {
            clusters,
            configs,
            code,
        })
    }

    /// Returns the only value `cluster` can decode to, if its histogram has a single symbol
    /// and that symbol is below `split` (so it decodes to itself without extra bits).
    #[inline]
    fn single_token(&self, cluster: u8) -> Option<u32> {
        let single_symbol = self.code.single_symbol(cluster)?;
        let IntegerConfig { split, .. } = self.configs[cluster as usize];
        (single_symbol < split).then_some(single_symbol)
    }

    /// Reads one integer from `cluster` without any LZ77 handling.
    ///
    /// The name mirrors the LZ77 variant; no distance multiplier is involved here.
    #[inline]
    fn read_varint_with_multiplier_clustered(
        &mut self,
        bitstream: &mut Bitstream,
        cluster: u8,
    ) -> CodingResult<u32> {
        let token = self.code.read_symbol(bitstream, cluster)?;
        Ok(self.read_uint_prefilled(bitstream, &self.configs[cluster as usize], token))
    }

    /// Reads one integer from `cluster`, resolving LZ77 copies through `state`.
    ///
    /// While a copy is in progress, the next value comes from the window and no symbol is
    /// read. Otherwise one symbol is read from `cluster`:
    /// - a token below `min_symbol` is a literal;
    /// - any other token starts a copy. Its length is read with `state.lz_len_conf` and
    ///   offset by `min_length`; its distance is read from the LZ77 distance cluster.
    ///
    /// Every produced value is appended to the window.
    ///
    /// # Errors
    /// - [`Error::UnexpectedLz77Repeat`] if a copy starts before any value was decoded;
    /// - [`Error::InvalidLz77Symbol`] if the copy length overflows `u32`;
    /// - coder errors are propagated.
    fn read_varint_with_multiplier_clustered_lz77(
        &mut self,
        bitstream: &mut Bitstream,
        cluster: u8,
        dist_multiplier: u32,
        state: &mut Lz77State,
        min_symbol: u32,
        min_length: u32,
    ) -> CodingResult<u32> {
        // Each entry is `[offset, dist]`, giving `offset + dist_multiplier * dist`.
        // NOTE(review): presumably (x offset, rows back) for a row width of
        // `dist_multiplier`; confirm against the spec.
        #[rustfmt::skip]
        const SPECIAL_DISTANCES: [[i8; 2]; 120] = [
            [0, 1], [1, 0], [1, 1], [-1, 1], [0, 2], [2, 0], [1, 2], [-1, 2], [2, 1], [-2, 1],
            [2, 2], [-2, 2], [0, 3], [3, 0], [1, 3], [-1, 3], [3, 1], [-3, 1], [2, 3], [-2, 3],
            [3, 2], [-3, 2], [0, 4], [4, 0], [1, 4], [-1, 4], [4, 1], [-4, 1], [3, 3], [-3, 3],
            [2, 4], [-2, 4], [4, 2], [-4, 2], [0, 5], [3, 4], [-3, 4], [4, 3], [-4, 3], [5, 0],
            [1, 5], [-1, 5], [5, 1], [-5, 1], [2, 5], [-2, 5], [5, 2], [-5, 2], [4, 4], [-4, 4],
            [3, 5], [-3, 5], [5, 3], [-5, 3], [0, 6], [6, 0], [1, 6], [-1, 6], [6, 1], [-6, 1],
            [2, 6], [-2, 6], [6, 2], [-6, 2], [4, 5], [-4, 5], [5, 4], [-5, 4], [3, 6], [-3, 6],
            [6, 3], [-6, 3], [0, 7], [7, 0], [1, 7], [-1, 7], [5, 5], [-5, 5], [7, 1], [-7, 1],
            [4, 6], [-4, 6], [6, 4], [-6, 4], [2, 7], [-2, 7], [7, 2], [-7, 2], [3, 7], [-3, 7],
            [7, 3], [-7, 3], [5, 6], [-5, 6], [6, 5], [-6, 5], [8, 0], [4, 7], [-4, 7], [7, 4],
            [-7, 4], [8, 1], [8, 2], [6, 6], [-6, 6], [8, 3], [5, 7], [-5, 7], [7, 5], [-7, 5],
            [8, 4], [6, 7], [-6, 7], [7, 6], [-7, 6], [8, 5], [7, 7], [-7, 7], [8, 6], [8, 7],
        ];

        let r;
        if state.num_to_copy > 0 {
            // Continue the current copy: replay the next value from the window.
            r = state.window[(state.copy_pos & 0xfffff) as usize];
            state.copy_pos += 1;
            state.num_to_copy -= 1;
        } else {
            let token = self.code.read_symbol(bitstream, cluster)?;
            if token >= min_symbol {
                // Start of an LZ77 copy; there must be something to copy from.
                if state.num_decoded == 0 {
                    tracing::error!("LZ77 repeat symbol encountered without decoding any symbols");
                    return Err(Error::UnexpectedLz77Repeat);
                }

                let lz_dist_cluster = self.lz_dist_cluster();

                let num_to_copy =
                    self.read_uint_prefilled(bitstream, &state.lz_len_conf, token - min_symbol);
                let Some(num_to_copy) = num_to_copy.checked_add(min_length) else {
                    tracing::error!(num_to_copy, min_length, "LZ77 num_to_copy overflow");
                    return Err(Error::InvalidLz77Symbol);
                };
                state.num_to_copy = num_to_copy;

                let token = self.code.read_symbol(bitstream, lz_dist_cluster)?;
                let distance = self.read_uint_prefilled(
                    bitstream,
                    &self.configs[lz_dist_cluster as usize],
                    token,
                );
                // Map the coded value to a raw distance. With a zero multiplier it is used
                // as-is. Otherwise values below 120 index SPECIAL_DISTANCES (clamped so the
                // final distance is at least 1), and larger values are shifted down by 120.
                let distance = if dist_multiplier == 0 {
                    distance
                } else if distance < 120 {
                    let [offset, dist] = SPECIAL_DISTANCES[distance as usize];
                    let dist = offset as i32 + dist_multiplier as i32 * dist as i32;
                    (dist - 1).max(0) as u32
                } else {
                    distance - 120
                };

                // Final distance is `min(distance + 1, 2^20)`, further capped at
                // `num_decoded`, so the copy never starts before the first decoded value.
                let distance = (((1 << 20) - 1).min(distance) + 1).min(state.num_decoded);
                state.copy_pos = state.num_decoded - distance;

                // The first copied value is produced immediately.
                r = state.window[(state.copy_pos & 0xfffff) as usize];
                state.copy_pos += 1;
                state.num_to_copy -= 1;
            } else {
                // Literal token.
                r = self.read_uint_prefilled(bitstream, &self.configs[cluster as usize], token);
            }
        }
        // Record the value in the 2^20-entry ring window; it grows by `push` until full.
        let offset = (state.num_decoded & 0xfffff) as usize;
        if state.window.len() <= offset {
            state.window.push(r);
        } else {
            state.window[offset] = r;
        }
        state.num_decoded += 1;
        Ok(r)
    }

    /// Expands `token` into an integer according to the hybrid-integer `config`.
    ///
    /// Tokens below `split` are returned unchanged. Otherwise:
    /// - the token provides `lsb_in_token` low bits;
    /// - it also provides `msb_in_token` high bits, placed under an implicit leading 1;
    /// - `n` middle bits are taken from the bitstream.
    ///
    /// NOTE(review): this assumes the bitstream buffer is already filled (it uses
    /// `peek_bits_prefilled`), and the result of `consume_bits` is deliberately ignored.
    #[inline]
    fn read_uint_prefilled(
        &self,
        bitstream: &mut Bitstream,
        config: &IntegerConfig,
        token: u32,
    ) -> u32 {
        let &IntegerConfig {
            split_exponent,
            split,
            msb_in_token,
            lsb_in_token,
            ..
        } = config;
        if token < split {
            return token;
        }

        let n = split_exponent - (msb_in_token + lsb_in_token)
            + ((token - split) >> (msb_in_token + lsb_in_token));
        // Keep the bit count below 32 even for out-of-range tokens.
        let n = n & 31;
        let rest_bits = bitstream.peek_bits_prefilled(n as usize) as u64;
        bitstream.consume_bits(n as usize).ok();

        let low_bits = token & ((1 << lsb_in_token) - 1);
        let low_bits = low_bits as u64;
        let token = token >> lsb_in_token;
        let token = token & ((1 << msb_in_token) - 1);
        let token = token | (1 << msb_in_token);
        let token = token as u64;
        let result = (((token << n) | rest_bits) << lsb_in_token) | low_bits;
        result as u32
    }

    /// Returns the cluster used for LZ77 distances: that of the last context, which
    /// `Decoder::parse` allocates as the extra distribution when LZ77 is enabled.
    ///
    /// # Panics
    /// Panics if the cluster map is empty.
    #[inline]
    fn lz_dist_cluster(&self) -> u8 {
        *self.clusters.last().unwrap()
    }
}
612
/// Symbol decoder holding one histogram per cluster.
#[derive(Debug, Clone)]
enum Coder {
    /// Prefix (Huffman-style) codes, one per cluster.
    PrefixCode(Arc<Vec<prefix::Histogram>>),
    /// ANS coding, one histogram per cluster.
    Ans {
        // Per-cluster ANS histograms.
        dist: Arc<Vec<ans::Histogram>>,
        // Current 32-bit ANS state.
        state: u32,
        // `true` until the initial state has been read from the bitstream.
        initial: bool,
    },
}
622
623impl Coder {
624 #[inline(always)]
625 fn read_symbol(&mut self, bitstream: &mut Bitstream, cluster: u8) -> CodingResult<u32> {
626 match self {
627 Self::PrefixCode(dist) => {
628 let dist = &dist[cluster as usize];
629 dist.read_symbol(bitstream)
630 }
631 Self::Ans {
632 dist,
633 state,
634 initial,
635 } => {
636 if *initial {
637 *state = bitstream.read_bits(32)?;
638 *initial = false;
639 }
640 let dist = &dist[cluster as usize];
641 dist.read_symbol(bitstream, state)
642 }
643 }
644 }
645
646 #[inline]
647 fn single_symbol(&self, cluster: u8) -> Option<u32> {
648 match self {
649 Self::PrefixCode(dist) => dist[cluster as usize].single_symbol(),
650 Self::Ans { dist, .. } => dist[cluster as usize].single_symbol(),
651 }
652 }
653
654 fn begin(&mut self, bitstream: &mut Bitstream) -> CodingResult<()> {
655 match self {
656 Self::PrefixCode(_) => Ok(()),
657 Self::Ans { state, initial, .. } => {
658 *state = bitstream.read_bits(32)?;
659 *initial = false;
660 Ok(())
661 }
662 }
663 }
664
665 fn finalize(&self) -> CodingResult<()> {
666 match *self {
667 Self::PrefixCode(_) => Ok(()),
668 Self::Ans { state, .. } => {
669 if state == 0x130000 {
670 Ok(())
671 } else {
672 Err(Error::InvalidAnsStream)
673 }
674 }
675 }
676 }
677}
678
/// Returns `ceil(log2(x + 1))`, the number of bits needed to represent `x`.
///
/// The result is `0` for `x == 0` and `32` for every `x >= 2^31`. It cannot overflow,
/// because it counts significant bits instead of computing `x + 1`.
fn add_log2_ceil(x: u32) -> u32 {
    u32::BITS - x.leading_zeros()
}
686
687pub fn read_clusters(bitstream: &mut Bitstream, num_dist: u32) -> CodingResult<(u32, Vec<u8>)> {
689 if num_dist == 1 {
690 return Ok((1, vec![0u8]));
691 }
692
693 let cluster = if bitstream.read_bool()? {
694 let nbits = bitstream.read_bits(2)? as usize;
696 (0..num_dist)
697 .map(|_| bitstream.read_bits(nbits).map(|b| b as u8))
698 .collect::<std::result::Result<Vec<_>, _>>()?
699 } else {
700 let use_mtf = bitstream.read_bool()?;
701 let mut decoder = if num_dist <= 2 {
702 Decoder::parse_assume_no_lz77(bitstream, 1)?
703 } else {
704 Decoder::parse(bitstream, 1)?
705 };
706 decoder.begin(bitstream)?;
707 let mut ret = (0..num_dist)
708 .map(|_| -> CodingResult<_> {
709 let b = decoder.read_varint(bitstream, 0)?;
710 u8::try_from(b).map_err(|_| Error::InvalidCluster(b))
711 })
712 .collect::<CodingResult<Vec<_>>>()?;
713 decoder.finalize()?;
714 if use_mtf {
715 let mut mtfmap = [0u8; 256];
716 for (idx, mtf) in mtfmap.iter_mut().enumerate() {
717 *mtf = idx as u8;
718 }
719 for cluster in &mut ret {
720 let idx = *cluster as usize;
721 *cluster = mtfmap[idx];
722 mtfmap.copy_within(0..idx, 1);
723 mtfmap[0] = *cluster;
724 }
725 }
726 ret
727 };
728
729 let num_clusters = *cluster.iter().max().unwrap() as u32 + 1;
730 let set = cluster
731 .iter()
732 .copied()
733 .collect::<std::collections::HashSet<_>>();
734 let num_expected_clusters = num_clusters;
735 let num_actual_clusters = set.len() as u32;
736 if num_actual_clusters != num_expected_clusters {
737 tracing::error!(
738 num_expected_clusters,
739 num_actual_clusters,
740 "distribution cluster has a hole"
741 );
742 Err(Error::ClusterHole {
743 num_expected_clusters,
744 num_actual_clusters,
745 })
746 } else {
747 Ok((num_clusters, cluster))
748 }
749}