1use std::sync::Arc;
8
9use jxl_bitstream::{Bitstream, U};
10
11mod ans;
12mod error;
13mod permutation;
14mod prefix;
15
16pub use error::Error;
17
/// Shorthand for a `Result` whose error type is this crate's [`Error`].
pub type CodingResult<T> = std::result::Result<T, Error>;
20
21pub use permutation::read_permutation;
22
/// Entropy decoder: the LZ77 mode parsed from the bitstream paired with the inner symbol decoder.
#[derive(Debug, Clone)]
pub struct Decoder {
    /// LZ77 mode; carries the running copy state when enabled.
    lz77: Lz77,
    /// Cluster map, per-cluster integer configurations and the symbol coder.
    inner: DecoderInner,
}
29
30impl Decoder {
31 pub fn parse(bitstream: &mut Bitstream, num_dist: u32) -> CodingResult<Self> {
34 let lz77 = Lz77::parse(bitstream)?;
35 let num_dist = if let Lz77::Disabled = &lz77 {
36 num_dist
37 } else {
38 num_dist + 1
39 };
40 let inner = DecoderInner::parse(bitstream, num_dist)?;
41 Ok(Self { lz77, inner })
42 }
43
44 fn parse_assume_no_lz77(bitstream: &mut Bitstream, num_dist: u32) -> CodingResult<Self> {
45 let lz77_enabled = bitstream.read_bool()?;
46 if lz77_enabled {
47 return Err(Error::Lz77NotAllowed);
48 }
49 let inner = DecoderInner::parse(bitstream, num_dist)?;
50 Ok(Self {
51 lz77: Lz77::Disabled,
52 inner,
53 })
54 }
55
56 #[inline]
58 pub fn read_varint(&mut self, bitstream: &mut Bitstream, ctx: u32) -> CodingResult<u32> {
59 self.read_varint_with_multiplier(bitstream, ctx, 0)
60 }
61
62 #[inline]
64 pub fn read_varint_with_multiplier(
65 &mut self,
66 bitstream: &mut Bitstream,
67 ctx: u32,
68 dist_multiplier: u32,
69 ) -> CodingResult<u32> {
70 let cluster = self.inner.clusters[ctx as usize];
71 self.read_varint_with_multiplier_clustered(bitstream, cluster, dist_multiplier)
72 }
73
74 #[inline(always)]
78 pub fn read_varint_with_multiplier_clustered(
79 &mut self,
80 bitstream: &mut Bitstream,
81 cluster: u8,
82 dist_multiplier: u32,
83 ) -> CodingResult<u32> {
84 if let Lz77::Enabled {
85 ref mut state,
86 min_symbol,
87 min_length,
88 } = self.lz77
89 {
90 self.inner.read_varint_with_multiplier_clustered_lz77(
91 bitstream,
92 cluster,
93 dist_multiplier,
94 state,
95 min_symbol,
96 min_length,
97 )
98 } else {
99 self.inner
100 .read_varint_with_multiplier_clustered(bitstream, cluster)
101 }
102 }
103
104 pub fn as_rle(&mut self) -> Option<DecoderRleMode<'_>> {
106 let &Lz77::Enabled {
107 ref state,
108 min_symbol,
109 min_length,
110 } = &self.lz77
111 else {
112 return None;
113 };
114 let lz_cluster = self.inner.lz_dist_cluster();
115 let lz_conf = &self.inner.configs[lz_cluster as usize];
116 let sym = self.inner.code.single_symbol(lz_cluster)?;
117 (sym == 1 && lz_conf.split_exponent == 0).then_some(DecoderRleMode {
118 inner: &mut self.inner,
119 min_symbol,
120 min_length,
121 len_config: state.lz_len_conf.clone(),
122 })
123 }
124
125 pub fn as_with_lz77(&mut self) -> Option<DecoderWithLz77<'_>> {
127 if let Lz77::Enabled {
128 ref mut state,
129 min_symbol,
130 min_length,
131 } = self.lz77
132 {
133 Some(DecoderWithLz77 {
134 inner: &mut self.inner,
135 state,
136 min_symbol,
137 min_length,
138 })
139 } else {
140 None
141 }
142 }
143
144 pub fn as_no_lz77(&mut self) -> Option<DecoderNoLz77<'_>> {
146 if let Lz77::Disabled = self.lz77 {
147 Some(DecoderNoLz77(&mut self.inner))
148 } else {
149 None
150 }
151 }
152
153 #[inline]
155 pub fn single_token(&self, cluster: u8) -> Option<u32> {
156 self.inner.single_token(cluster)
157 }
158
159 #[inline]
164 pub fn begin(&mut self, bitstream: &mut Bitstream) -> CodingResult<()> {
165 self.inner.code.begin(bitstream)
166 }
167
168 #[inline]
173 pub fn finalize(&self) -> CodingResult<()> {
174 self.inner.code.finalize()
175 }
176
177 #[inline]
179 pub fn cluster_map(&self) -> &[u8] {
180 &self.inner.clusters
181 }
182}
183
/// Borrowed view of a [`Decoder`] produced by `Decoder::as_rle`.
#[derive(Debug)]
pub struct DecoderRleMode<'dec> {
    /// The decoder's inner state (cluster map, configs, symbol coder).
    inner: &'dec mut DecoderInner,
    /// Tokens greater than or equal to this value are decoded as repeat lengths.
    min_symbol: u32,
    /// Added to every decoded repeat length.
    min_length: u32,
    /// Integer configuration used to decode repeat lengths (cloned from the LZ77 state).
    len_config: IntegerConfig,
}
192
/// Result of reading one token in RLE mode.
#[derive(Debug, Copy, Clone)]
pub enum RleToken {
    /// A literal decoded integer.
    Value(u32),
    /// A repeat length (decoded length plus the configured minimum length).
    // NOTE(review): presumably the caller repeats the previously decoded value this many
    // times; confirm against callers.
    Repeat(u32),
}
201
202impl DecoderRleMode<'_> {
203 #[inline]
207 pub fn read_varint_clustered(
208 &mut self,
209 bitstream: &mut Bitstream,
210 cluster: u8,
211 ) -> CodingResult<RleToken> {
212 self.inner
213 .code
214 .read_symbol(bitstream, cluster)
215 .map(|token| {
216 if let Some(token) = token.checked_sub(self.min_symbol) {
217 RleToken::Repeat(
218 self.inner
219 .read_uint_prefilled(bitstream, &self.len_config, token)
220 + self.min_length,
221 )
222 } else {
223 RleToken::Value(self.inner.read_uint_prefilled(
224 bitstream,
225 &self.inner.configs[cluster as usize],
226 token,
227 ))
228 }
229 })
230 }
231
232 #[inline]
234 pub fn cluster_map(&self) -> &[u8] {
235 &self.inner.clusters
236 }
237}
238
/// Borrowed view of a [`Decoder`] whose LZ77 mode is enabled (see `Decoder::as_with_lz77`).
#[derive(Debug)]
pub struct DecoderWithLz77<'dec> {
    /// The decoder's inner state (cluster map, configs, symbol coder).
    inner: &'dec mut DecoderInner,
    /// Running LZ77 copy state.
    state: &'dec mut Lz77State,
    /// Tokens greater than or equal to this value start an LZ77 copy.
    min_symbol: u32,
    /// Added to every decoded copy length.
    min_length: u32,
}
247
248impl DecoderWithLz77<'_> {
249 #[inline]
253 pub fn read_varint_with_multiplier_clustered(
254 &mut self,
255 bitstream: &mut Bitstream,
256 cluster: u8,
257 dist_multiplier: u32,
258 ) -> CodingResult<u32> {
259 self.inner.read_varint_with_multiplier_clustered_lz77(
260 bitstream,
261 cluster,
262 dist_multiplier,
263 self.state,
264 self.min_symbol,
265 self.min_length,
266 )
267 }
268
269 #[inline]
271 pub fn cluster_map(&self) -> &[u8] {
272 &self.inner.clusters
273 }
274}
275
/// Borrowed view of a [`Decoder`] whose LZ77 mode is disabled (see `Decoder::as_no_lz77`).
#[derive(Debug)]
pub struct DecoderNoLz77<'dec>(&'dec mut DecoderInner);
279
280impl DecoderNoLz77<'_> {
281 #[inline]
285 pub fn read_varint_clustered(
286 &mut self,
287 bitstream: &mut Bitstream,
288 cluster: u8,
289 ) -> CodingResult<u32> {
290 self.0
291 .read_varint_with_multiplier_clustered(bitstream, cluster)
292 }
293
294 #[inline]
296 pub fn single_token(&self, cluster: u8) -> Option<u32> {
297 self.0.single_token(cluster)
298 }
299
300 #[inline]
302 pub fn cluster_map(&self) -> &[u8] {
303 &self.0.clusters
304 }
305}
306
/// LZ77 mode of a decoder, as signalled by the first bit read in `Lz77::parse`.
#[derive(Debug, Clone)]
enum Lz77 {
    /// LZ77 is not used.
    Disabled,
    /// LZ77 is used with the given parameters.
    Enabled {
        /// Tokens greater than or equal to this value start an LZ77 copy.
        min_symbol: u32,
        /// Added to every decoded copy length.
        min_length: u32,
        /// Running copy state, including the window of previously decoded values.
        state: Lz77State,
    },
}
316
317impl Lz77 {
318 fn parse(bitstream: &mut Bitstream) -> CodingResult<Self> {
319 Ok(if bitstream.read_bool()? {
320 let min_symbol = bitstream.read_u32(224, 512, 4096, 8 + U(15))?;
322 let min_length = bitstream.read_u32(3, 4, 5 + U(2), 9 + U(8))?;
323 let lz_len_conf = IntegerConfig::parse(bitstream, 8)?;
324 Self::Enabled {
325 min_symbol,
326 min_length,
327 state: Lz77State::new(lz_len_conf),
328 }
329 } else {
330 Self::Disabled
331 })
332 }
333}
334
/// Mutable state carried across reads while LZ77 is enabled.
#[derive(Clone)]
struct Lz77State {
    /// Integer configuration used to decode copy lengths.
    lz_len_conf: IntegerConfig,
    /// Previously decoded values, indexed by position masked with `0xfffff`.
    window: Vec<u32>,
    /// Number of values still to be copied from the window before reading a new token.
    num_to_copy: u32,
    /// Position (before masking) of the next value to copy from the window.
    copy_pos: u32,
    /// Total number of values decoded so far.
    num_decoded: u32,
}
343
344impl std::fmt::Debug for Lz77State {
345 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
346 f.debug_struct("Lz77State")
347 .field("lz_len_conf", &self.lz_len_conf)
348 .field("num_to_copy", &self.num_to_copy)
349 .field("copy_pos", &self.copy_pos)
350 .field("num_decoded", &self.num_decoded)
351 .finish_non_exhaustive()
352 }
353}
354
355impl Lz77State {
356 fn new(lz_len_conf: IntegerConfig) -> Self {
357 Self {
358 lz_len_conf,
359 window: Vec::new(),
360 num_to_copy: 0,
361 copy_pos: 0,
362 num_decoded: 0,
363 }
364 }
365}
366
/// Parameters describing how a token is expanded into an integer (see
/// `DecoderInner::read_uint_prefilled`).
#[derive(Debug, Clone)]
struct IntegerConfig {
    /// Exponent of the split point, as read from the bitstream.
    split_exponent: u32,
    /// `1 << split_exponent`; tokens below this value are returned unchanged.
    split: u32,
    /// Number of bits taken from the token as the high bits below the implicit leading one.
    msb_in_token: u32,
    /// Number of low bits of the result taken directly from the token.
    lsb_in_token: u32,
}
374
375impl IntegerConfig {
376 fn parse(bitstream: &mut Bitstream, log_alphabet_size: u32) -> CodingResult<Self> {
377 let split_exponent_bits = add_log2_ceil(log_alphabet_size);
378 let split_exponent = bitstream.read_bits(split_exponent_bits as usize)?;
379 let (msb_in_token, lsb_in_token) = if split_exponent != log_alphabet_size {
380 let msb_bits = add_log2_ceil(split_exponent) as usize;
381 let msb_in_token = bitstream.read_bits(msb_bits)?;
382 if msb_in_token > split_exponent {
383 return Err(Error::InvalidIntegerConfig {
384 split_exponent,
385 msb_in_token,
386 lsb_in_token: None,
387 });
388 }
389 let lsb_bits = add_log2_ceil(split_exponent - msb_in_token) as usize;
390 let lsb_in_token = bitstream.read_bits(lsb_bits)?;
391 (msb_in_token, lsb_in_token)
392 } else {
393 (0u32, 0u32)
394 };
395 if lsb_in_token + msb_in_token > split_exponent {
396 return Err(Error::InvalidIntegerConfig {
397 split_exponent,
398 msb_in_token,
399 lsb_in_token: Some(lsb_in_token),
400 });
401 }
402 Ok(Self {
403 split_exponent,
404 split: 1 << split_exponent,
405 msb_in_token,
406 lsb_in_token,
407 })
408 }
409}
410
/// Decoder state that does not depend on the LZ77 mode.
#[derive(Debug, Clone)]
struct DecoderInner {
    /// Context-to-cluster map.
    clusters: Vec<u8>,
    /// One integer configuration per cluster.
    configs: Vec<IntegerConfig>,
    /// Symbol coder (prefix code or ANS) holding one histogram per cluster.
    code: Coder,
}
417
418impl DecoderInner {
419 fn parse(bitstream: &mut Bitstream, num_dist: u32) -> CodingResult<Self> {
420 let (num_clusters, clusters) = read_clusters(bitstream, num_dist)?;
421 let use_prefix_code = bitstream.read_bool()?;
422 let log_alphabet_size = if use_prefix_code {
423 15
424 } else {
425 bitstream.read_bits(2)? + 5
426 };
427 let configs = (0..num_clusters)
428 .map(|_| IntegerConfig::parse(bitstream, log_alphabet_size))
429 .collect::<CodingResult<Vec<_>>>()?;
430 let code = if use_prefix_code {
431 let counts = (0..num_clusters)
432 .map(|_| -> CodingResult<_> {
433 let count = if bitstream.read_bool()? {
434 let n = bitstream.read_bits(4)? as usize;
435 1 + (1 << n) + bitstream.read_bits(n)?
436 } else {
437 1
438 };
439 if count > 1 << 15 {
440 return Err(Error::InvalidPrefixHistogram);
441 }
442 Ok(count)
443 })
444 .collect::<CodingResult<Vec<_>>>()?;
445 let dist = counts
446 .into_iter()
447 .map(|count| prefix::Histogram::parse(bitstream, count))
448 .collect::<CodingResult<Vec<_>>>()?;
449 Coder::PrefixCode(Arc::new(dist))
450 } else {
451 let dist = (0..num_clusters)
452 .map(|_| ans::Histogram::parse(bitstream, log_alphabet_size))
453 .collect::<CodingResult<Vec<_>>>()?;
454 Coder::Ans {
455 dist: Arc::new(dist),
456 state: 0,
457 initial: true,
458 }
459 };
460 Ok(Self {
461 clusters,
462 configs,
463 code,
464 })
465 }
466
467 #[inline]
468 fn single_token(&self, cluster: u8) -> Option<u32> {
469 let single_symbol = self.code.single_symbol(cluster)?;
470 let IntegerConfig { split, .. } = self.configs[cluster as usize];
471 (single_symbol < split).then_some(single_symbol)
472 }
473
474 #[inline]
475 fn read_varint_with_multiplier_clustered(
476 &mut self,
477 bitstream: &mut Bitstream,
478 cluster: u8,
479 ) -> CodingResult<u32> {
480 let token = self.code.read_symbol(bitstream, cluster)?;
481 Ok(self.read_uint_prefilled(bitstream, &self.configs[cluster as usize], token))
482 }
483
484 fn read_varint_with_multiplier_clustered_lz77(
485 &mut self,
486 bitstream: &mut Bitstream,
487 cluster: u8,
488 dist_multiplier: u32,
489 state: &mut Lz77State,
490 min_symbol: u32,
491 min_length: u32,
492 ) -> CodingResult<u32> {
493 #[rustfmt::skip]
494 const SPECIAL_DISTANCES: [[i8; 2]; 120] = [
495 [0, 1], [1, 0], [1, 1], [-1, 1], [0, 2], [2, 0], [1, 2], [-1, 2], [2, 1], [-2, 1],
496 [2, 2], [-2, 2], [0, 3], [3, 0], [1, 3], [-1, 3], [3, 1], [-3, 1], [2, 3], [-2, 3],
497 [3, 2], [-3, 2], [0, 4], [4, 0], [1, 4], [-1, 4], [4, 1], [-4, 1], [3, 3], [-3, 3],
498 [2, 4], [-2, 4], [4, 2], [-4, 2], [0, 5], [3, 4], [-3, 4], [4, 3], [-4, 3], [5, 0],
499 [1, 5], [-1, 5], [5, 1], [-5, 1], [2, 5], [-2, 5], [5, 2], [-5, 2], [4, 4], [-4, 4],
500 [3, 5], [-3, 5], [5, 3], [-5, 3], [0, 6], [6, 0], [1, 6], [-1, 6], [6, 1], [-6, 1],
501 [2, 6], [-2, 6], [6, 2], [-6, 2], [4, 5], [-4, 5], [5, 4], [-5, 4], [3, 6], [-3, 6],
502 [6, 3], [-6, 3], [0, 7], [7, 0], [1, 7], [-1, 7], [5, 5], [-5, 5], [7, 1], [-7, 1],
503 [4, 6], [-4, 6], [6, 4], [-6, 4], [2, 7], [-2, 7], [7, 2], [-7, 2], [3, 7], [-3, 7],
504 [7, 3], [-7, 3], [5, 6], [-5, 6], [6, 5], [-6, 5], [8, 0], [4, 7], [-4, 7], [7, 4],
505 [-7, 4], [8, 1], [8, 2], [6, 6], [-6, 6], [8, 3], [5, 7], [-5, 7], [7, 5], [-7, 5],
506 [8, 4], [6, 7], [-6, 7], [7, 6], [-7, 6], [8, 5], [7, 7], [-7, 7], [8, 6], [8, 7],
507 ];
508
509 let r;
510 if state.num_to_copy > 0 {
511 r = state.window[(state.copy_pos & 0xfffff) as usize];
512 state.copy_pos += 1;
513 state.num_to_copy -= 1;
514 } else {
515 let token = self.code.read_symbol(bitstream, cluster)?;
516 if token >= min_symbol {
517 if state.num_decoded == 0 {
518 tracing::error!("LZ77 repeat symbol encountered without decoding any symbols");
519 return Err(Error::UnexpectedLz77Repeat);
520 }
521
522 let lz_dist_cluster = self.lz_dist_cluster();
523
524 let num_to_copy =
525 self.read_uint_prefilled(bitstream, &state.lz_len_conf, token - min_symbol);
526 let Some(num_to_copy) = num_to_copy.checked_add(min_length) else {
527 tracing::error!(num_to_copy, min_length, "LZ77 num_to_copy overflow");
528 return Err(Error::InvalidLz77Symbol);
529 };
530 state.num_to_copy = num_to_copy;
531
532 let token = self.code.read_symbol(bitstream, lz_dist_cluster)?;
533 let distance = self.read_uint_prefilled(
534 bitstream,
535 &self.configs[lz_dist_cluster as usize],
536 token,
537 );
538 let distance = if dist_multiplier == 0 {
539 distance
540 } else if distance < 120 {
541 let [offset, dist] = SPECIAL_DISTANCES[distance as usize];
542 let dist = offset as i32 + dist_multiplier as i32 * dist as i32;
543 (dist - 1).max(0) as u32
544 } else {
545 distance - 120
546 };
547
548 let distance = (((1 << 20) - 1).min(distance) + 1).min(state.num_decoded);
549 state.copy_pos = state.num_decoded - distance;
550
551 r = state.window[(state.copy_pos & 0xfffff) as usize];
552 state.copy_pos += 1;
553 state.num_to_copy -= 1;
554 } else {
555 r = self.read_uint_prefilled(bitstream, &self.configs[cluster as usize], token);
556 }
557 }
558 let offset = (state.num_decoded & 0xfffff) as usize;
559 if state.window.len() <= offset {
560 state.window.push(r);
561 } else {
562 state.window[offset] = r;
563 }
564 state.num_decoded += 1;
565 Ok(r)
566 }
567
568 #[inline]
569 fn read_uint_prefilled(
570 &self,
571 bitstream: &mut Bitstream,
572 config: &IntegerConfig,
573 token: u32,
574 ) -> u32 {
575 let &IntegerConfig {
576 split_exponent,
577 split,
578 msb_in_token,
579 lsb_in_token,
580 ..
581 } = config;
582 if token < split {
583 return token;
584 }
585
586 let n = split_exponent - (msb_in_token + lsb_in_token)
587 + ((token - split) >> (msb_in_token + lsb_in_token));
588 let n = n & 31;
590 let rest_bits = bitstream.peek_bits_prefilled(n as usize) as u64;
591 bitstream.consume_bits(n as usize).ok();
592
593 let low_bits = token & ((1 << lsb_in_token) - 1);
594 let low_bits = low_bits as u64;
595 let token = token >> lsb_in_token;
596 let token = token & ((1 << msb_in_token) - 1);
597 let token = token | (1 << msb_in_token);
598 let token = token as u64;
599 let result = (((token << n) | rest_bits) << lsb_in_token) | low_bits;
600 result as u32
602 }
603
604 #[inline]
605 fn lz_dist_cluster(&self) -> u8 {
606 *self.clusters.last().unwrap()
607 }
608}
609
/// Symbol coder: one histogram per cluster, plus decoding state for ANS.
#[derive(Debug, Clone)]
enum Coder {
    /// Prefix-code histograms, indexed by cluster.
    PrefixCode(Arc<Vec<prefix::Histogram>>),
    /// ANS histograms with the running decoder state.
    Ans {
        /// ANS histograms, indexed by cluster.
        dist: Arc<Vec<ans::Histogram>>,
        /// Current ANS state; loaded from 32 bits of the bitstream.
        state: u32,
        /// `true` until the initial state has been read.
        initial: bool,
    },
}
619
620impl Coder {
621 #[inline(always)]
622 fn read_symbol(&mut self, bitstream: &mut Bitstream, cluster: u8) -> CodingResult<u32> {
623 match self {
624 Self::PrefixCode(dist) => {
625 let dist = &dist[cluster as usize];
626 dist.read_symbol(bitstream)
627 }
628 Self::Ans {
629 dist,
630 state,
631 initial,
632 } => {
633 if *initial {
634 *state = bitstream.read_bits(32)?;
635 *initial = false;
636 }
637 let dist = &dist[cluster as usize];
638 dist.read_symbol(bitstream, state)
639 }
640 }
641 }
642
643 #[inline]
644 fn single_symbol(&self, cluster: u8) -> Option<u32> {
645 match self {
646 Self::PrefixCode(dist) => dist[cluster as usize].single_symbol(),
647 Self::Ans { dist, .. } => dist[cluster as usize].single_symbol(),
648 }
649 }
650
651 fn begin(&mut self, bitstream: &mut Bitstream) -> CodingResult<()> {
652 match self {
653 Self::PrefixCode(_) => Ok(()),
654 Self::Ans { state, initial, .. } => {
655 *state = bitstream.read_bits(32)?;
656 *initial = false;
657 Ok(())
658 }
659 }
660 }
661
662 fn finalize(&self) -> CodingResult<()> {
663 match *self {
664 Self::PrefixCode(_) => Ok(()),
665 Self::Ans { state, .. } => {
666 if state == 0x130000 {
667 Ok(())
668 } else {
669 Err(Error::InvalidAnsStream)
670 }
671 }
672 }
673 }
674}
675
/// Returns `ceil(log2(x + 1))`, i.e. the number of bits needed to represent `x`.
///
/// Computed from the leading-zero count so that no overflow can occur for large `x`
/// (values with the top bit set yield 32).
fn add_log2_ceil(x: u32) -> u32 {
    u32::BITS - x.leading_zeros()
}
683
684pub fn read_clusters(bitstream: &mut Bitstream, num_dist: u32) -> CodingResult<(u32, Vec<u8>)> {
686 if num_dist == 1 {
687 return Ok((1, vec![0u8]));
688 }
689
690 let cluster = if bitstream.read_bool()? {
691 let nbits = bitstream.read_bits(2)? as usize;
693 (0..num_dist)
694 .map(|_| bitstream.read_bits(nbits).map(|b| b as u8))
695 .collect::<std::result::Result<Vec<_>, _>>()?
696 } else {
697 let use_mtf = bitstream.read_bool()?;
698 let mut decoder = if num_dist <= 2 {
699 Decoder::parse_assume_no_lz77(bitstream, 1)?
700 } else {
701 Decoder::parse(bitstream, 1)?
702 };
703 decoder.begin(bitstream)?;
704 let mut ret = (0..num_dist)
705 .map(|_| -> CodingResult<_> {
706 let b = decoder.read_varint(bitstream, 0)?;
707 u8::try_from(b).map_err(|_| Error::InvalidCluster(b))
708 })
709 .collect::<CodingResult<Vec<_>>>()?;
710 decoder.finalize()?;
711 if use_mtf {
712 let mut mtfmap = [0u8; 256];
713 for (idx, mtf) in mtfmap.iter_mut().enumerate() {
714 *mtf = idx as u8;
715 }
716 for cluster in &mut ret {
717 let idx = *cluster as usize;
718 *cluster = mtfmap[idx];
719 mtfmap.copy_within(0..idx, 1);
720 mtfmap[0] = *cluster;
721 }
722 }
723 ret
724 };
725
726 let num_clusters = *cluster.iter().max().unwrap() as u32 + 1;
727 let set = cluster
728 .iter()
729 .copied()
730 .collect::<std::collections::HashSet<_>>();
731 let num_expected_clusters = num_clusters;
732 let num_actual_clusters = set.len() as u32;
733 if num_actual_clusters != num_expected_clusters {
734 tracing::error!(
735 num_expected_clusters,
736 num_actual_clusters,
737 "distribution cluster has a hole"
738 );
739 Err(Error::ClusterHole {
740 num_expected_clusters,
741 num_actual_clusters,
742 })
743 } else {
744 Ok((num_clusters, cluster))
745 }
746}