1use std::collections::HashMap;
2
3use jxl_bitstream::Bitstream;
4use jxl_coding::{Decoder, DecoderRleMode, RleToken};
5use jxl_grid::{AlignedGrid, AllocTracker, MutableSubgrid};
6
7use crate::{
8 MaConfig, ModularChannelInfo, ModularChannels, ModularHeader, Result,
9 ma::{FlatMaTree, MaTreeLeafClustered, SimpleMaTable},
10 predictor::{Predictor, PredictorState, Properties, WpHeader},
11 sample::Sample,
12};
13
/// A (possibly nested) mutable view into a decoded channel buffer.
///
/// Transforms may fold several channel grids into one logical channel;
/// `Merged` keeps the leading grid together with the grids merged into it so
/// they can later be detached again (see `merge` / `unmerge`).
#[derive(Debug)]
pub enum TransformedGrid<'dest, S: Sample> {
    /// A plain, standalone channel grid.
    Single(MutableSubgrid<'dest, S>),
    /// A grid that currently carries additional member grids.
    Merged {
        leader: MutableSubgrid<'dest, S>,
        members: Vec<TransformedGrid<'dest, S>>,
    },
}
22
impl<'dest, S: Sample> From<MutableSubgrid<'dest, S>> for TransformedGrid<'dest, S> {
    /// Wraps a bare subgrid as a standalone (`Single`) transformed grid.
    fn from(value: MutableSubgrid<'dest, S>) -> Self {
        Self::Single(value)
    }
}
28
impl<S: Sample> TransformedGrid<'_, S> {
    /// Creates a new `TransformedGrid` viewing the same region as `self`.
    ///
    /// `split_horizontal(0).1` splits off an empty left part and keeps the
    /// right part, which covers the whole grid — effectively a reborrow of
    /// the underlying buffer with a shorter lifetime. Note that a `Merged`
    /// grid is reborrowed as `Single`, exposing only the leader grid.
    fn reborrow(&mut self) -> TransformedGrid<S> {
        match self {
            TransformedGrid::Single(g) => TransformedGrid::Single(g.split_horizontal(0).1),
            TransformedGrid::Merged { leader, .. } => {
                TransformedGrid::Single(leader.split_horizontal(0).1)
            }
        }
    }
}
39
impl<'dest, S: Sample> TransformedGrid<'dest, S> {
    /// Returns the primary grid: the grid itself for `Single`, or the leader
    /// of a merged group.
    pub(crate) fn grid(&self) -> &MutableSubgrid<'dest, S> {
        match self {
            Self::Single(g) => g,
            Self::Merged { leader, .. } => leader,
        }
    }

    /// Mutable counterpart of [`Self::grid`].
    pub(crate) fn grid_mut(&mut self) -> &mut MutableSubgrid<'dest, S> {
        match self {
            Self::Single(g) => g,
            Self::Merged { leader, .. } => leader,
        }
    }

    /// Attaches `members` to this grid, turning it into (or extending) a
    /// `Merged` grid. No-op when `members` is empty.
    pub(crate) fn merge(&mut self, members: Vec<TransformedGrid<'dest, S>>) {
        if members.is_empty() {
            return;
        }

        match self {
            Self::Single(leader) => {
                // Move the current grid out (leaving an empty placeholder) so
                // it can become the leader of the merged group.
                let tmp = MutableSubgrid::empty();
                let leader = std::mem::replace(leader, tmp);
                *self = Self::Merged { leader, members };
            }
            Self::Merged {
                members: original_members,
                ..
            } => {
                original_members.extend(members);
            }
        }
    }

    /// Detaches and returns the last `count` merged members.
    ///
    /// Collapses back to `Single` when every member has been removed.
    ///
    /// # Panics
    /// Panics if `self` is `Single` (and `count > 0`), or if `count` exceeds
    /// the number of merged members (the subtraction below underflows).
    pub(crate) fn unmerge(&mut self, count: usize) -> Vec<TransformedGrid<'dest, S>> {
        if count == 0 {
            return Vec::new();
        }

        match self {
            Self::Single(_) => panic!("cannot unmerge TransformedGrid::Single"),
            Self::Merged { leader, members } => {
                let len = members.len();
                // Remove the `count` most recently merged members.
                let members = members.drain((len - count)..).collect();
                if len == count {
                    // No members left; collapse back into a plain grid.
                    let tmp = MutableSubgrid::empty();
                    let leader = std::mem::replace(leader, tmp);
                    *self = Self::Single(leader);
                }
                members
            }
        }
    }
}
95
/// Output buffers and metadata for decoding a complete modular image.
#[derive(Debug)]
pub struct ModularImageDestination<S: Sample> {
    header: ModularHeader,
    ma_ctx: MaConfig,
    // Dimension of a coding group in pixels; asserted nonzero before use.
    group_dim: u32,
    bit_depth: u32,
    channels: ModularChannels,
    // Scratch channels requested by the header's transforms
    // (see `prepare_meta_channels`).
    meta_channels: Vec<AlignedGrid<S>>,
    // One backing buffer per decoded image channel.
    image_channels: Vec<AlignedGrid<S>>,
}
106
107impl<S: Sample> ModularImageDestination<S> {
108 pub(crate) fn new(
109 header: ModularHeader,
110 ma_ctx: MaConfig,
111 group_dim: u32,
112 bit_depth: u32,
113 channels: ModularChannels,
114 tracker: Option<&AllocTracker>,
115 ) -> Result<Self> {
116 let mut meta_channels = Vec::new();
117 for tr in &header.transform {
118 tr.prepare_meta_channels(&mut meta_channels, tracker)?;
119 }
120
121 let image_channels = channels
122 .info
123 .iter()
124 .map(|ch| {
125 AlignedGrid::with_alloc_tracker(ch.width as usize, ch.height as usize, tracker)
126 })
127 .collect::<std::result::Result<_, _>>()?;
128
129 Ok(Self {
130 header,
131 ma_ctx,
132 group_dim,
133 bit_depth,
134 channels,
135 meta_channels,
136 image_channels,
137 })
138 }
139
140 pub fn try_clone(&self) -> Result<Self> {
141 Ok(Self {
142 header: self.header.clone(),
143 ma_ctx: self.ma_ctx.clone(),
144 group_dim: self.group_dim,
145 bit_depth: self.bit_depth,
146 channels: self.channels.clone(),
147 meta_channels: self
148 .meta_channels
149 .iter()
150 .map(|x| x.try_clone())
151 .collect::<std::result::Result<_, _>>()?,
152 image_channels: self
153 .image_channels
154 .iter()
155 .map(|x| x.try_clone())
156 .collect::<std::result::Result<_, _>>()?,
157 })
158 }
159
160 pub fn image_channels(&self) -> &[AlignedGrid<S>] {
161 &self.image_channels
162 }
163
164 pub fn into_image_channels(self) -> Vec<AlignedGrid<S>> {
165 self.image_channels
166 }
167
168 pub fn into_image_channels_with_info(
169 self,
170 ) -> impl Iterator<Item = (AlignedGrid<S>, ModularChannelInfo)> {
171 self.image_channels.into_iter().zip(self.channels.info)
172 }
173
174 pub fn has_palette(&self) -> bool {
175 self.header.transform.iter().any(|tr| tr.is_palette())
176 }
177
178 pub fn has_squeeze(&self) -> bool {
179 self.header.transform.iter().any(|tr| tr.is_squeeze())
180 }
181}
182
impl<S: Sample> ModularImageDestination<S> {
    /// Prepares the subimage decoded in the GlobalModular section.
    ///
    /// Channels are kept in the global subimage while they are meta channels
    /// or while both dimensions fit within a single group (`group_dim`); all
    /// following channels are decoded by per-group streams instead.
    pub fn prepare_gmodular(&mut self) -> Result<TransformedModularSubimage<S>> {
        assert_ne!(self.group_dim, 0);

        let group_dim = self.group_dim;
        let subimage = self.prepare_subimage()?;
        let (channel_info, grids): (Vec<_>, Vec<_>) = subimage
            .channel_info
            .into_iter()
            .zip(subimage.grid)
            .enumerate()
            .take_while(|&(i, (ref info, _))| {
                i < subimage.nb_meta_channels
                    || (info.width <= group_dim && info.height <= group_dim)
            })
            .map(|(_, x)| x)
            .unzip();
        let channel_indices = (0..channel_info.len()).collect();
        Ok(TransformedModularSubimage {
            channel_info,
            channel_indices,
            grid: grids,
            ..subimage
        })
    }

    /// Splits the transformed channels into per-group subimages.
    ///
    /// `pass_shifts` maps a pass index to its `(minshift, maxshift)` range.
    /// Each non-global channel is routed either to the LF groups (both
    /// shifts >= 3) or to the pass whose shift range contains the channel's
    /// smaller shift.
    pub fn prepare_groups(
        &mut self,
        pass_shifts: &std::collections::BTreeMap<u32, (i32, i32)>,
    ) -> Result<TransformedGlobalModular<S>> {
        assert_ne!(self.group_dim, 0);

        // Pass indices are the map keys; the largest key determines how many
        // per-pass group lists are needed.
        let num_passes = *pass_shifts.last_key_value().unwrap().0 as usize + 1;

        let group_dim = self.group_dim;
        let group_dim_shift = group_dim.trailing_zeros();
        let bit_depth = self.bit_depth;
        let subimage = self.prepare_subimage()?;
        // Skip channels already covered by the GlobalModular subimage
        // (same predicate as `prepare_gmodular`).
        let it = subimage
            .channel_info
            .into_iter()
            .zip(subimage.grid)
            .enumerate()
            .skip_while(|&(i, (ref info, _))| {
                i < subimage.nb_meta_channels
                    || (info.width <= group_dim && info.height <= group_dim)
            });

        let mut lf_groups = Vec::new();
        let mut pass_groups = Vec::with_capacity(num_passes);
        pass_groups.resize_with(num_passes, Vec::new);
        for (i, (info, grid)) in it {
            let ModularChannelInfo {
                original_width,
                original_height,
                hshift,
                vshift,
                original_shift,
                ..
            } = info;
            assert!(hshift >= 0 && vshift >= 0);

            let grid = match grid {
                TransformedGrid::Single(g) => g,
                TransformedGrid::Merged { leader, .. } => leader,
            };
            tracing::trace!(
                i,
                width = grid.width(),
                height = grid.height(),
                hshift,
                vshift
            );

            let (groups, grids) = if hshift < 3 || vshift < 3 {
                // Pass-group channel: find the pass whose shift range
                // contains this channel's smaller shift.
                let shift = hshift.min(vshift);
                let pass_idx = *pass_shifts
                    .iter()
                    .find(|&(_, &(minshift, maxshift))| (minshift..maxshift).contains(&shift))
                    .unwrap()
                    .0;
                let pass_idx = pass_idx as usize;

                let group_width = group_dim >> hshift;
                let group_height = group_dim >> vshift;
                if group_width == 0 || group_height == 0 {
                    tracing::error!(
                        group_dim,
                        hshift,
                        vshift,
                        "Channel shift value too large after transform"
                    );
                    return Err(crate::Error::InvalidSqueezeParams);
                }

                // Group counts are the original dimensions divided by
                // group_dim, rounded up.
                let grids = grid.into_groups_with_fixed_count(
                    group_width as usize,
                    group_height as usize,
                    (original_width + group_dim - 1) as usize >> group_dim_shift,
                    (original_height + group_dim - 1) as usize >> group_dim_shift,
                );
                (&mut pass_groups[pass_idx], grids)
            } else {
                // Both shifts >= 3: LF-group channel. LF groups cover an
                // area 8x larger in each dimension, hence the `- 3` / `<< 3`
                // adjustments below.
                let lf_group_width = group_dim >> (hshift - 3);
                let lf_group_height = group_dim >> (vshift - 3);
                if lf_group_width == 0 || lf_group_height == 0 {
                    tracing::error!(
                        group_dim,
                        hshift,
                        vshift,
                        "Channel shift value too large after transform"
                    );
                    return Err(crate::Error::InvalidSqueezeParams);
                }
                let grids = grid.into_groups_with_fixed_count(
                    lf_group_width as usize,
                    lf_group_height as usize,
                    (original_width + (group_dim << 3) - 1) as usize >> (group_dim_shift + 3),
                    (original_height + (group_dim << 3) - 1) as usize >> (group_dim_shift + 3),
                );
                (&mut lf_groups, grids)
            };

            if groups.is_empty() {
                // First channel routed to this group list: create one empty
                // subimage per group.
                groups.resize_with(grids.len(), || {
                    TransformedModularSubimage::empty(&subimage.header, &subimage.ma_ctx, bit_depth)
                });
            } else if groups.len() != grids.len() {
                // Every channel must split into the same number of groups.
                panic!();
            }

            for (subimage, grid) in groups.iter_mut().zip(grids) {
                let width = grid.width() as u32;
                let height = grid.height() as u32;
                if width == 0 || height == 0 {
                    continue;
                }

                subimage.channel_info.push(ModularChannelInfo {
                    width,
                    height,
                    original_width: width << hshift,
                    original_height: height << vshift,
                    hshift,
                    vshift,
                    original_shift,
                });
                subimage.channel_indices.push(i);
                subimage.grid.push(grid.into());
                subimage.partial = true;
            }
        }

        Ok(TransformedGlobalModular {
            lf_groups,
            pass_groups,
        })
    }

    /// Applies the forward transforms and returns a subimage covering every
    /// transformed channel, borrowing the destination buffers.
    pub fn prepare_subimage(&mut self) -> Result<TransformedModularSubimage<S>> {
        let mut channels = self.channels.clone();
        let mut meta_channel_grids = self
            .meta_channels
            .iter_mut()
            .map(MutableSubgrid::from)
            .collect::<Vec<_>>();
        let mut grids = self
            .image_channels
            .iter_mut()
            .map(|g| g.as_subgrid_mut().into())
            .collect::<Vec<_>>();
        for tr in &self.header.transform {
            tr.transform_channels(&mut channels, &mut meta_channel_grids, &mut grids)?;
        }

        let channel_info = channels.info;
        let channel_indices = (0..channel_info.len()).collect();
        Ok(TransformedModularSubimage {
            header: self.header.clone(),
            ma_ctx: self.ma_ctx.clone(),
            bit_depth: self.bit_depth,
            nb_meta_channels: channels.nb_meta_channels as usize,
            channel_info,
            channel_indices,
            grid: grids,
            partial: true,
        })
    }
}
373
/// Per-group subimages produced by `ModularImageDestination::prepare_groups`.
#[derive(Debug)]
pub struct TransformedGlobalModular<'dest, S: Sample> {
    /// One subimage per LF group.
    pub lf_groups: Vec<TransformedModularSubimage<'dest, S>>,
    /// One list of subimages per pass, indexed by pass then group.
    pub pass_groups: Vec<Vec<TransformedModularSubimage<'dest, S>>>,
}
379
/// A set of transformed channels to be decoded from one entropy stream,
/// borrowing the destination buffers.
#[derive(Debug)]
pub struct TransformedModularSubimage<'dest, S: Sample> {
    header: ModularHeader,
    ma_ctx: MaConfig,
    bit_depth: u32,
    nb_meta_channels: usize,
    channel_info: Vec<ModularChannelInfo>,
    // Index of each channel within the full channel list of the image.
    channel_indices: Vec<usize>,
    grid: Vec<TransformedGrid<'dest, S>>,
    // True until the subimage has been completely decoded.
    partial: bool,
}
391
impl<S: Sample> TransformedModularSubimage<'_, S> {
    /// Creates a subimage with no channels, used as a placeholder for a
    /// group before channels are routed into it.
    fn empty(header: &ModularHeader, ma_ctx: &MaConfig, bit_depth: u32) -> Self {
        Self {
            header: header.clone(),
            ma_ctx: ma_ctx.clone(),
            bit_depth,
            nb_meta_channels: 0,
            channel_info: Vec::new(),
            channel_indices: Vec::new(),
            grid: Vec::new(),
            // An empty subimage is trivially "complete".
            partial: false,
        }
    }
}
406
impl<'dest, S: Sample> TransformedModularSubimage<'dest, S> {
    /// Returns `true` when the subimage has no channels to decode.
    pub fn is_empty(&self) -> bool {
        self.channel_info.is_empty()
    }

    /// Returns `true` when the subimage has not been fully decoded yet.
    pub fn is_partial(&self) -> bool {
        self.partial
    }

    /// Reads a local modular header from `bitstream` and turns this subimage
    /// into a recursive modular image decoding into the same destination
    /// grids.
    pub fn recursive(
        self,
        bitstream: &mut Bitstream,
        global_ma_config: Option<&MaConfig>,
        tracker: Option<&AllocTracker>,
    ) -> Result<RecursiveModularImage<'dest, S>> {
        let channels = crate::ModularChannels {
            info: self.channel_info,
            nb_meta_channels: 0,
        };
        let (header, ma_ctx) = crate::read_and_validate_local_modular_header(
            bitstream,
            &channels,
            global_ma_config,
            tracker,
        )?;

        let mut image = RecursiveModularImage {
            header,
            ma_ctx,
            bit_depth: self.bit_depth,
            channels,
            meta_channels: Vec::new(),
            image_channels: self.grid,
        };
        // The local header carries its own transforms, which may require
        // scratch channels of their own.
        for tr in &image.header.transform {
            tr.prepare_meta_channels(&mut image.meta_channels, tracker)?;
        }
        Ok(image)
    }

    /// Runs the inverse transforms over the decoded channels.
    ///
    /// Returns `true` when the subimage was completely decoded.
    pub fn finish(mut self, pool: &jxl_threadpool::JxlThreadPool) -> bool {
        for tr in self.header.transform.iter().rev() {
            tr.inverse(&mut self.grid, self.bit_depth, pool);
        }
        !self.partial
    }
}
454
impl<S: Sample> TransformedModularSubimage<'_, S> {
    /// Decodes every channel of the subimage from a single entropy stream.
    ///
    /// Selects the libjxl fast-lossless RLE path when every channel uses a
    /// single gradient leaf; otherwise decodes channel by channel with the
    /// appropriate fast or slow path.
    fn decode_inner(&mut self, bitstream: &mut Bitstream, stream_index: u32) -> Result<()> {
        let span = tracing::span!(tracing::Level::TRACE, "decode channels", stream_index);
        let _guard = span.enter();

        // Widest channel width; passed down to the entropy decoder.
        let dist_multiplier = self
            .channel_info
            .iter()
            .map(|info| info.width)
            .max()
            .unwrap_or(0);

        let mut decoder = self.ma_ctx.decoder().clone();
        decoder.begin(bitstream)?;

        // Flatten the MA tree for each channel up front. `None` marks
        // zero-sized channels, which are skipped entirely.
        let mut ma_tree_list = Vec::with_capacity(self.channel_info.len());
        for (i, info) in self.channel_info.iter().enumerate() {
            if info.width == 0 || info.height == 0 {
                ma_tree_list.push(None);
                continue;
            }

            // Count earlier channels with identical dimensions and shifts;
            // these are the "previous channels" the tree may reference.
            let filtered_prev_len = self.channel_info[..i]
                .iter()
                .filter(|prev_info| {
                    info.width == prev_info.width
                        && info.height == prev_info.height
                        && info.hshift == prev_info.hshift
                        && info.vshift == prev_info.vshift
                })
                .count();

            let ma_tree =
                self.ma_ctx
                    .make_flat_tree(i as u32, stream_index, filtered_prev_len as u32);
            ma_tree_list.push(Some(ma_tree));
        }

        if let Some(mut rle_decoder) = decoder.as_rle() {
            // Fast-lossless requires every channel to be a single gradient
            // leaf with identity affine correction.
            let is_fast_lossless = ma_tree_list.iter().all(|ma_tree| {
                ma_tree
                    .as_ref()
                    .map(|ma_tree| {
                        matches!(
                            ma_tree.single_node(),
                            Some(MaTreeLeafClustered {
                                predictor: Predictor::Gradient,
                                offset: 0,
                                multiplier: 1,
                                ..
                            })
                        )
                    })
                    .unwrap_or(true)
            });

            if is_fast_lossless {
                tracing::trace!("libjxl fast-lossless");
                let mut rle_state = RleState::<S>::new();

                for (ma_tree, grid) in ma_tree_list.into_iter().zip(&mut self.grid) {
                    let Some(ma_tree) = ma_tree else {
                        continue;
                    };

                    let node = ma_tree.single_node().unwrap();
                    let cluster = node.cluster;
                    decode_fast_lossless(
                        bitstream,
                        &mut rle_decoder,
                        &mut rle_state,
                        cluster,
                        grid.grid_mut(),
                    );
                }

                // Errors were latched during decoding; surface the first one.
                rle_state.check_error()?;
                return Ok(());
            }
        }

        let wp_header = &self.header.wp_params;
        let mut predictor = PredictorState::new();
        // Decoded grids grouped by (width, height, hshift, vshift); the slow
        // path consults these as "previous channel" property sources.
        let mut prev_map = HashMap::new();
        for ((info, ma_tree), grid) in self
            .channel_info
            .iter()
            .zip(ma_tree_list)
            .zip(&mut self.grid)
        {
            let Some(ma_tree) = ma_tree else {
                continue;
            };
            let key = (info.width, info.height, info.hshift, info.vshift);

            let filtered_prev = prev_map.entry(key).or_insert_with(Vec::new);

            if let Some(node) = ma_tree.single_node() {
                // Entire tree is one leaf.
                decode_single_node(
                    bitstream,
                    &mut decoder,
                    dist_multiplier,
                    &mut predictor,
                    wp_header,
                    grid.grid_mut(),
                    node,
                )?;
            } else if let Some(table) = ma_tree.simple_table() {
                // Tree flattens to a lookup table on one decision property.
                decode_simple_table(
                    bitstream,
                    &mut decoder,
                    dist_multiplier,
                    &mut predictor,
                    wp_header,
                    grid.grid_mut(),
                    &table,
                )?;
            } else {
                // General tree: full per-sample tree traversal.
                let grid = grid.grid_mut();
                let filtered_prev = &filtered_prev[..ma_tree.max_prev_channel_depth()];
                let wp_header = ma_tree.need_self_correcting().then_some(wp_header);
                predictor.reset(grid.width() as u32, filtered_prev, wp_header);
                decode_slow(
                    bitstream,
                    &mut decoder,
                    dist_multiplier,
                    &ma_tree,
                    &mut predictor,
                    grid,
                )?;
            }

            // Most recently decoded channel goes first in the prev list.
            filtered_prev.insert(0, grid.grid());
        }

        decoder.finalize()?;
        Ok(())
    }

    /// Decodes the subimage from `bitstream`.
    ///
    /// When `allow_partial` is set, an unexpected end of stream is not an
    /// error; the subimage is simply left partially decoded (`partial`
    /// stays `true`).
    pub fn decode(
        &mut self,
        bitstream: &mut Bitstream,
        stream_index: u32,
        allow_partial: bool,
    ) -> Result<()> {
        match self.decode_inner(bitstream, stream_index) {
            Err(e) if e.unexpected_eof() && allow_partial => {
                tracing::debug!("Partially decoded Modular image");
            }
            Err(e) => return Err(e),
            Ok(_) => {
                self.partial = false;
            }
        }
        Ok(())
    }
}
613
/// A modular image decoded recursively inside another modular stream,
/// writing into grids borrowed from the outer image.
#[derive(Debug)]
pub struct RecursiveModularImage<'dest, S: Sample> {
    header: ModularHeader,
    ma_ctx: MaConfig,
    bit_depth: u32,
    channels: ModularChannels,
    // Scratch channels owned by this image, filled by its own transforms.
    meta_channels: Vec<AlignedGrid<S>>,
    // Destination grids borrowed from the outer image.
    image_channels: Vec<TransformedGrid<'dest, S>>,
}
623
impl<S: Sample> RecursiveModularImage<'_, S> {
    /// Applies this image's transforms and returns a subimage reborrowing
    /// the underlying grids; mirrors
    /// `ModularImageDestination::prepare_subimage`.
    pub fn prepare_subimage(&mut self) -> Result<TransformedModularSubimage<S>> {
        let mut channels = self.channels.clone();
        let mut meta_channel_grids = self
            .meta_channels
            .iter_mut()
            .map(|g| {
                let width = g.width();
                let height = g.height();
                // View the whole grid; stride equals the width.
                MutableSubgrid::from_buf(g.buf_mut(), width, height, width)
            })
            .collect::<Vec<_>>();
        let mut grids = self
            .image_channels
            .iter_mut()
            .map(|g| g.reborrow())
            .collect();
        for tr in &self.header.transform {
            tr.transform_channels(&mut channels, &mut meta_channel_grids, &mut grids)?;
        }

        let channel_info = channels.info;
        let channel_indices = (0..channel_info.len()).collect();
        Ok(TransformedModularSubimage {
            header: self.header.clone(),
            ma_ctx: self.ma_ctx.clone(),
            bit_depth: self.bit_depth,
            nb_meta_channels: channels.nb_meta_channels as usize,
            channel_info,
            channel_indices,
            grid: grids,
            partial: true,
        })
    }
}
659
/// Decoder state for the RLE (run-length) coding mode.
struct RleState<S: Sample> {
    // Last decoded sample value; repeated while a run is active.
    value: S,
    // Remaining number of times `value` should be emitted.
    repeat: u32,
    // First decode error, deferred until `check_error` is called.
    error: Option<Box<jxl_coding::Error>>,
}
665
impl<S: Sample> RleState<S> {
    /// Creates a fresh state with no pending run and no recorded error.
    #[inline]
    fn new() -> Self {
        Self {
            value: S::default(),
            repeat: 0,
            error: None,
        }
    }

    /// Decodes the next sample in RLE mode.
    ///
    /// Decode errors are latched into `self.error` instead of being
    /// returned, so the caller can keep filling the grid and report the
    /// first error later via [`Self::check_error`]. After an error with no
    /// active run, `repeat` wraps around to a huge value, so subsequent
    /// calls keep returning the last value without touching the bitstream.
    #[inline(always)]
    fn decode(
        &mut self,
        bitstream: &mut Bitstream,
        decoder: &mut DecoderRleMode,
        cluster: u8,
    ) -> S {
        if self.repeat == 0 {
            let result = decoder.read_varint_clustered(bitstream, cluster);
            match result {
                Ok(RleToken::Value(v)) => {
                    self.value = S::unpack_signed_u32(v);
                    self.repeat = 1;
                }
                Ok(RleToken::Repeat(len)) => {
                    // Re-emit the previous value `len` times.
                    self.repeat = len;
                }
                Err(e) if self.error.is_none() => {
                    // Record only the first error.
                    self.error = Some(Box::new(e));
                }
                _ => {}
            }
        }

        self.repeat = self.repeat.wrapping_sub(1);
        self.value
    }

    /// Returns the first latched decode error, if any, clearing it.
    #[inline]
    fn check_error(&mut self) -> Result<()> {
        let error = self.error.take();
        if let Some(error) = error {
            let error = *error;
            Err(error.into())
        } else {
            Ok(())
        }
    }
}
715
/// Decodes a channel whose MA tree is a single leaf, choosing the fastest
/// applicable strategy based on the leaf's predictor and cluster.
fn decode_single_node<S: Sample>(
    bitstream: &mut Bitstream,
    decoder: &mut Decoder,
    dist_multiplier: u32,
    predictor_state: &mut PredictorState<S>,
    wp_header: &WpHeader,
    grid: &mut MutableSubgrid<S>,
    node: &MaTreeLeafClustered,
) -> Result<()> {
    let &MaTreeLeafClustered {
        cluster,
        predictor,
        offset,
        multiplier,
    } = node;
    tracing::trace!(cluster, ?predictor, "Single MA tree node");

    let height = grid.height();
    // `Some` when the cluster can only ever produce one token
    // (see `Decoder::single_token`).
    let single_token = decoder.single_token(cluster);
    match (predictor, single_token) {
        (Predictor::Zero, Some(token)) => {
            // Every sample decodes to the same constant; fill rows directly.
            tracing::trace!("Single token in cluster, Zero predictor: hyper fast path");
            let value = S::unpack_signed_u32(token).wrapping_muladd_i32(multiplier as i32, offset);
            for y in 0..height {
                grid.get_row_mut(y).fill(value);
            }
            Ok(())
        }
        (Predictor::Zero, None) => {
            // No prediction: each sample is just the decoded residual with
            // the leaf's affine correction applied.
            tracing::trace!("Zero predictor: fast path");
            for y in 0..height {
                let row = grid.get_row_mut(y);
                for out in row {
                    let token = decoder.read_varint_with_multiplier_clustered(
                        bitstream,
                        cluster,
                        dist_multiplier,
                    )?;
                    *out =
                        S::unpack_signed_u32(token).wrapping_muladd_i32(multiplier as i32, offset);
                }
            }
            Ok(())
        }
        (Predictor::Gradient, _) if offset == 0 && multiplier == 1 => {
            // Gradient predictor with identity affine correction.
            tracing::trace!("Simple gradient: quite fast path");
            decode_simple_grad(bitstream, decoder, cluster, dist_multiplier, grid)
        }
        _ => {
            // General case: run the full predictor state machine. The WP
            // header is only needed by the self-correcting predictor.
            let wp_header = (predictor == Predictor::SelfCorrecting).then_some(wp_header);
            predictor_state.reset(grid.width() as u32, &[], wp_header);
            decode_single_node_slow(
                bitstream,
                decoder,
                dist_multiplier,
                node,
                predictor_state,
                grid,
            )
        }
    }
}
778
/// Fast path for libjxl "fast-lossless" streams: RLE-coded residuals with a
/// gradient predictor and identity affine correction.
///
/// Errors are accumulated inside `rle_state` rather than returned; the
/// caller must invoke `RleState::check_error` afterwards.
#[inline(never)]
fn decode_fast_lossless<S: Sample>(
    bitstream: &mut Bitstream,
    decoder: &mut DecoderRleMode,
    rle_state: &mut RleState<S>,
    cluster: u8,
    grid: &mut MutableSubgrid<S>,
) {
    let height = grid.height();

    {
        // First row: no row above, so each sample is predicted from its
        // west neighbor only (starting from zero).
        let mut w = S::default();
        let out_row = grid.get_row_mut(0);
        for out in &mut *out_row {
            let token = rle_state.decode(bitstream, decoder, cluster);
            w = w.add(token);
            *out = w;
        }
    }

    for y in 1..height {
        // Split so the previous row can be read while writing the current.
        let (u, mut d) = grid.split_vertical(y);
        let prev_row = u.get_row(y - 1);
        let out_row = d.get_row_mut(0);

        // First column: predict from the north neighbor.
        let token = rle_state.decode(bitstream, decoder, cluster);
        let mut w = token.add(prev_row[0]);
        out_row[0] = w;

        // Remaining columns: clamped gradient prediction from the north,
        // west and northwest neighbors.
        for (window, out) in prev_row.windows(2).zip(&mut out_row[1..]) {
            let nw = window[0];
            let n = window[1];
            let pred = S::grad_clamped(n, w, nw);

            let token = rle_state.decode(bitstream, decoder, cluster);
            w = token.add(pred);
            *out = w;
        }
    }
}
819
/// Decodes a channel with the gradient predictor (identity affine
/// correction) from a single cluster; mirrors `decode_fast_lossless` but
/// reads plain varints instead of RLE tokens.
#[inline(never)]
fn decode_simple_grad<S: Sample>(
    bitstream: &mut Bitstream,
    decoder: &mut Decoder,
    cluster: u8,
    dist_multiplier: u32,
    grid: &mut MutableSubgrid<S>,
) -> Result<()> {
    let width = grid.width();
    let height = grid.height();

    {
        // First row: predict from the west neighbor only (starting at zero).
        let mut w = S::default();
        let out_row = grid.get_row_mut(0);
        for out in out_row[..width].iter_mut() {
            let token = decoder.read_varint_with_multiplier_clustered(
                bitstream,
                cluster,
                dist_multiplier,
            )?;
            w = S::unpack_signed_u32(token).add(w);
            *out = w;
        }
    }

    for y in 1..height {
        // Split so the previous row can be read while writing the current.
        let (u, mut d) = grid.split_vertical(y);
        let prev_row = u.get_row(y - 1);
        let out_row = d.get_row_mut(0);

        // First column: predict from the north neighbor.
        let token =
            decoder.read_varint_with_multiplier_clustered(bitstream, cluster, dist_multiplier)?;
        let mut w = S::unpack_signed_u32(token).add(prev_row[0]);
        out_row[0] = w;

        // Remaining columns: clamped gradient prediction from north, west
        // and northwest neighbors.
        for (window, out) in prev_row.windows(2).zip(&mut out_row[1..]) {
            let nw = window[0];
            let n = window[1];
            let pred = S::grad_clamped(n, w, nw);

            let token = decoder.read_varint_with_multiplier_clustered(
                bitstream,
                cluster,
                dist_multiplier,
            )?;
            let value = S::unpack_signed_u32(token).add(pred);
            *out = value;
            w = value;
        }
    }

    Ok(())
}
873
874#[inline(always)]
875fn decode_one<S: Sample, const EDGE: bool>(
876 bitstream: &mut Bitstream,
877 decoder: &mut Decoder,
878 dist_multiplier: u32,
879 leaf: &MaTreeLeafClustered,
880 properties: &Properties<S>,
881) -> Result<S> {
882 let diff = S::unpack_signed_u32(decoder.read_varint_with_multiplier_clustered(
883 bitstream,
884 leaf.cluster,
885 dist_multiplier,
886 )?);
887 let diff = diff.wrapping_muladd_i32(leaf.multiplier as i32, leaf.offset);
888 let predictor = leaf.predictor;
889 let sample_prediction = predictor.predict::<_, EDGE>(properties);
890 Ok(diff.add(S::from_i32(sample_prediction)))
891}
892
/// General decoding loop for a single-leaf MA tree using the full predictor
/// state machine.
#[inline(never)]
fn decode_single_node_slow<S: Sample>(
    bitstream: &mut Bitstream,
    decoder: &mut Decoder,
    dist_multiplier: u32,
    leaf: &MaTreeLeafClustered,
    predictor: &mut PredictorState<S>,
    grid: &mut MutableSubgrid<S>,
) -> Result<()> {
    let height = grid.height();
    // First two rows: every sample may reference out-of-bounds neighbors,
    // so use the edge-aware property computation (`EDGE = true`).
    for y in 0..2usize.min(height) {
        let row = grid.get_row_mut(y);

        for out in row.iter_mut() {
            let properties = predictor.properties::<true>();
            let true_value =
                decode_one::<_, true>(bitstream, decoder, dist_multiplier, leaf, &properties)?;
            *out = true_value;
            properties.record(true_value.to_i32());
        }
    }

    for y in 2..height {
        let row = grid.get_row_mut(y);
        // Split each row into the left two samples, the interior, and the
        // right two samples; only the edges need `EDGE = true` handling.
        let (row_left, row_middle, row_right) = if row.len() <= 4 {
            (row, [].as_mut(), [].as_mut())
        } else {
            let (l, m) = row.split_at_mut(2);
            let (m, r) = m.split_at_mut(m.len() - 2);
            (l, m, r)
        };

        for out in row_left {
            let properties = predictor.properties::<true>();
            let true_value =
                decode_one::<_, true>(bitstream, decoder, dist_multiplier, leaf, &properties)?;
            *out = true_value;
            properties.record(true_value.to_i32());
        }
        for out in row_middle {
            let properties = predictor.properties::<false>();
            let true_value =
                decode_one::<_, false>(bitstream, decoder, dist_multiplier, leaf, &properties)?;
            *out = true_value;
            properties.record(true_value.to_i32());
        }
        for out in row_right {
            let properties = predictor.properties::<true>();
            let true_value =
                decode_one::<_, true>(bitstream, decoder, dist_multiplier, leaf, &properties)?;
            *out = true_value;
            properties.record(true_value.to_i32());
        }
    }

    Ok(())
}
950
951fn decode_simple_table<S: Sample>(
952 bitstream: &mut Bitstream,
953 decoder: &mut Decoder,
954 dist_multiplier: u32,
955 predictor_state: &mut PredictorState<S>,
956 wp_header: &WpHeader,
957 grid: &mut MutableSubgrid<S>,
958 table: &SimpleMaTable,
959) -> Result<()> {
960 let &SimpleMaTable {
961 decision_prop,
962 value_base,
963 predictor,
964 offset,
965 multiplier,
966 ref cluster_table,
967 } = table;
968
969 if offset == 0 && multiplier == 1 && decision_prop == 9 && predictor == Predictor::Gradient {
970 return decode_gradient_table(
971 bitstream,
972 decoder,
973 dist_multiplier,
974 grid,
975 value_base,
976 cluster_table,
977 );
978 }
979
980 decode_simple_table_slow(
981 bitstream,
982 decoder,
983 dist_multiplier,
984 predictor_state,
985 wp_header,
986 grid,
987 table,
988 )
989}
990
/// Maps a decision-property value to an entropy cluster via a lookup table.
///
/// The table covers property values `value_base..value_base + len`;
/// out-of-range values are clamped to the nearest end of the table.
///
/// The subtraction is done in `i64` so that extreme `sample` / `value_base`
/// combinations (both are decoded from untrusted input) cannot overflow
/// `i32`, which would panic in debug builds.
#[inline(always)]
fn cluster_from_table(sample: i32, value_base: i32, cluster_table: &[u8]) -> u8 {
    let index = (sample as i64 - value_base as i64).clamp(0, cluster_table.len() as i64 - 1);
    cluster_table[index as usize]
}
996
/// Gradient fast path for simple-table channels.
///
/// The decision property is computed inline as `N + W - NW` (which is what
/// the table's decision property evaluates to on this path), so the cluster
/// can be looked up directly without building the full property set.
fn decode_gradient_table<S: Sample>(
    bitstream: &mut Bitstream,
    decoder: &mut Decoder,
    dist_multiplier: u32,
    grid: &mut MutableSubgrid<S>,
    value_base: i32,
    cluster_table: &[u8],
) -> Result<()> {
    tracing::trace!("Gradient-only lookup table");

    let width = grid.width();
    let height = grid.height();

    {
        // First row: predict from the west neighbor only; the decision
        // property degenerates to the west value.
        let mut w = S::default();
        let out_row = grid.get_row_mut(0);
        for out in out_row[..width].iter_mut() {
            let cluster = cluster_from_table(w.to_i32(), value_base, cluster_table);
            let token = decoder.read_varint_with_multiplier_clustered(
                bitstream,
                cluster,
                dist_multiplier,
            )?;
            w = S::unpack_signed_u32(token).add(w);
            *out = w;
        }
    }

    for y in 1..height {
        // Split so the previous row can be read while writing the current.
        let (u, mut d) = grid.split_vertical(y);
        let prev_row = u.get_row(y - 1);
        let out_row = d.get_row_mut(0);

        // First column: predict from (and decide on) the north neighbor.
        let cluster = cluster_from_table(prev_row[0].to_i32(), value_base, cluster_table);
        let token =
            decoder.read_varint_with_multiplier_clustered(bitstream, cluster, dist_multiplier)?;
        let mut w = S::unpack_signed_u32(token).add(prev_row[0]);
        out_row[0] = w;

        for (window, out) in prev_row.windows(2).zip(&mut out_row[1..]) {
            let nw = window[0];
            let n = window[1];
            // Decision property: unclamped gradient N + W - NW.
            let prop = n
                .to_i32()
                .wrapping_add(w.to_i32())
                .wrapping_sub(nw.to_i32());
            // Prediction: clamped gradient.
            let pred = S::grad_clamped(n, w, nw);

            let cluster = cluster_from_table(prop, value_base, cluster_table);
            let token = decoder.read_varint_with_multiplier_clustered(
                bitstream,
                cluster,
                dist_multiplier,
            )?;
            let value = S::unpack_signed_u32(token).add(pred);
            *out = value;
            w = value;
        }
    }

    Ok(())
}
1059
1060#[inline(always)]
1061fn decode_table_one<S: Sample, const EDGE: bool>(
1062 bitstream: &mut Bitstream,
1063 decoder: &mut Decoder,
1064 dist_multiplier: u32,
1065 table: &SimpleMaTable,
1066 properties: &Properties<S>,
1067) -> Result<S> {
1068 let prop_value = properties.get(table.decision_prop as usize);
1069
1070 let cluster = cluster_from_table(prop_value, table.value_base, &table.cluster_table);
1071
1072 let diff = S::unpack_signed_u32(decoder.read_varint_with_multiplier_clustered(
1073 bitstream,
1074 cluster,
1075 dist_multiplier,
1076 )?);
1077 let diff = diff.wrapping_muladd_i32(table.multiplier as i32, table.offset);
1078 let predictor = table.predictor;
1079 let sample_prediction = predictor.predict::<_, EDGE>(properties);
1080 Ok(diff.add(S::from_i32(sample_prediction)))
1081}
1082
/// Generic decoding loop for simple-table channels that don't qualify for
/// the gradient fast path; runs the full predictor state machine.
#[inline(never)]
fn decode_simple_table_slow<S: Sample>(
    bitstream: &mut Bitstream,
    decoder: &mut Decoder,
    dist_multiplier: u32,
    predictor_state: &mut PredictorState<S>,
    wp_header: &WpHeader,
    grid: &mut MutableSubgrid<S>,
    table: &SimpleMaTable,
) -> Result<()> {
    tracing::trace!("Slow lookup table");

    // The WP header is needed when the predictor is self-correcting or when
    // the decision property is 15 (presumably WP-derived — see
    // `PredictorState`; TODO confirm).
    let need_wp_header = table.decision_prop == 15 || table.predictor == Predictor::SelfCorrecting;
    let wp_header = need_wp_header.then_some(wp_header);
    predictor_state.reset(grid.width() as u32, &[], wp_header);

    let height = grid.height();
    // First two rows: all samples use edge-aware properties (`EDGE = true`).
    for y in 0..2usize.min(height) {
        let row = grid.get_row_mut(y);

        for out in row.iter_mut() {
            let properties = predictor_state.properties::<true>();
            let true_value = decode_table_one::<_, true>(
                bitstream,
                decoder,
                dist_multiplier,
                table,
                &properties,
            )?;
            *out = true_value;
            properties.record(true_value.to_i32());
        }
    }

    for y in 2..height {
        let row = grid.get_row_mut(y);
        // Left two / interior / right two split; only the edges need
        // `EDGE = true` handling.
        let (row_left, row_middle, row_right) = if row.len() <= 4 {
            (row, [].as_mut(), [].as_mut())
        } else {
            let (l, m) = row.split_at_mut(2);
            let (m, r) = m.split_at_mut(m.len() - 2);
            (l, m, r)
        };

        for out in row_left {
            let properties = predictor_state.properties::<true>();
            let true_value = decode_table_one::<_, true>(
                bitstream,
                decoder,
                dist_multiplier,
                table,
                &properties,
            )?;
            *out = true_value;
            properties.record(true_value.to_i32());
        }
        for out in row_middle {
            let properties = predictor_state.properties::<false>();
            let true_value = decode_table_one::<_, false>(
                bitstream,
                decoder,
                dist_multiplier,
                table,
                &properties,
            )?;
            *out = true_value;
            properties.record(true_value.to_i32());
        }
        for out in row_right {
            let properties = predictor_state.properties::<true>();
            let true_value = decode_table_one::<_, true>(
                bitstream,
                decoder,
                dist_multiplier,
                table,
                &properties,
            )?;
            *out = true_value;
            properties.record(true_value.to_i32());
        }
    }

    Ok(())
}
1167
/// Fully general decoding loop: traverses the flattened MA tree for every
/// sample to find the leaf, then decodes with that leaf's predictor.
#[inline(never)]
fn decode_slow<S: Sample>(
    bitstream: &mut Bitstream,
    decoder: &mut Decoder,
    dist_multiplier: u32,
    ma_tree: &FlatMaTree,
    predictor: &mut PredictorState<S>,
    grid: &mut MutableSubgrid<S>,
) -> Result<()> {
    let height = grid.height();
    // First two rows: all samples use edge-aware properties (`EDGE = true`).
    for y in 0..2usize.min(height) {
        let row = grid.get_row_mut(y);

        for out in row.iter_mut() {
            let properties = predictor.properties::<true>();
            let leaf = ma_tree.get_leaf(&properties);
            let true_value =
                decode_one::<_, true>(bitstream, decoder, dist_multiplier, leaf, &properties)?;
            *out = true_value;
            properties.record(true_value.to_i32());
        }
    }

    for y in 2..height {
        let row = grid.get_row_mut(y);
        // Left two / interior / right two split; only the edges need
        // `EDGE = true` handling.
        let (row_left, row_middle, row_right) = if row.len() <= 4 {
            (row, [].as_mut(), [].as_mut())
        } else {
            let (l, m) = row.split_at_mut(2);
            let (m, r) = m.split_at_mut(m.len() - 2);
            (l, m, r)
        };

        for out in row_left {
            let properties = predictor.properties::<true>();
            let leaf = ma_tree.get_leaf(&properties);
            let true_value =
                decode_one::<_, true>(bitstream, decoder, dist_multiplier, leaf, &properties)?;
            *out = true_value;
            properties.record(true_value.to_i32());
        }
        for out in row_middle {
            let properties = predictor.properties::<false>();
            let leaf = ma_tree.get_leaf(&properties);
            let true_value =
                decode_one::<_, false>(bitstream, decoder, dist_multiplier, leaf, &properties)?;
            *out = true_value;
            properties.record(true_value.to_i32());
        }
        for out in row_right {
            let properties = predictor.properties::<true>();
            let leaf = ma_tree.get_leaf(&properties);
            let true_value =
                decode_one::<_, true>(bitstream, decoder, dist_multiplier, leaf, &properties)?;
            *out = true_value;
            properties.record(true_value.to_i32());
        }
    }

    Ok(())
}