jxl_modular/
image.rs

1use std::collections::HashMap;
2
3use jxl_bitstream::Bitstream;
4use jxl_coding::{Decoder, DecoderRleMode, RleToken};
5use jxl_grid::{AlignedGrid, AllocTracker, MutableSubgrid};
6
7use crate::{
8    MaConfig, ModularChannelInfo, ModularChannels, ModularHeader, Result,
9    ma::{FlatMaTree, MaTreeLeafClustered, SimpleMaTable},
10    predictor::{Predictor, PredictorState, Properties, WpHeader},
11    sample::Sample,
12};
13
/// A channel grid as seen after Modular transforms have been applied.
///
/// `Single` is a plain subgrid. `Merged` holds a `leader` subgrid plus `members` that were
/// attached to it with `merge`; accessors such as `grid`/`grid_mut` expose only the leader.
#[derive(Debug)]
pub enum TransformedGrid<'dest, S: Sample> {
    /// A single, unmerged subgrid.
    Single(MutableSubgrid<'dest, S>),
    /// A leader subgrid with grids merged into it; `unmerge` detaches members from the end.
    Merged {
        leader: MutableSubgrid<'dest, S>,
        members: Vec<TransformedGrid<'dest, S>>,
    },
}
22
23impl<'dest, S: Sample> From<MutableSubgrid<'dest, S>> for TransformedGrid<'dest, S> {
24    fn from(value: MutableSubgrid<'dest, S>) -> Self {
25        Self::Single(value)
26    }
27}
28
29impl<S: Sample> TransformedGrid<'_, S> {
30    fn reborrow(&mut self) -> TransformedGrid<S> {
31        match self {
32            TransformedGrid::Single(g) => TransformedGrid::Single(g.split_horizontal(0).1),
33            TransformedGrid::Merged { leader, .. } => {
34                TransformedGrid::Single(leader.split_horizontal(0).1)
35            }
36        }
37    }
38}
39
40impl<'dest, S: Sample> TransformedGrid<'dest, S> {
41    pub(crate) fn grid(&self) -> &MutableSubgrid<'dest, S> {
42        match self {
43            Self::Single(g) => g,
44            Self::Merged { leader, .. } => leader,
45        }
46    }
47
48    pub(crate) fn grid_mut(&mut self) -> &mut MutableSubgrid<'dest, S> {
49        match self {
50            Self::Single(g) => g,
51            Self::Merged { leader, .. } => leader,
52        }
53    }
54
55    pub(crate) fn merge(&mut self, members: Vec<TransformedGrid<'dest, S>>) {
56        if members.is_empty() {
57            return;
58        }
59
60        match self {
61            Self::Single(leader) => {
62                let tmp = MutableSubgrid::empty();
63                let leader = std::mem::replace(leader, tmp);
64                *self = Self::Merged { leader, members };
65            }
66            Self::Merged {
67                members: original_members,
68                ..
69            } => {
70                original_members.extend(members);
71            }
72        }
73    }
74
75    pub(crate) fn unmerge(&mut self, count: usize) -> Vec<TransformedGrid<'dest, S>> {
76        if count == 0 {
77            return Vec::new();
78        }
79
80        match self {
81            Self::Single(_) => panic!("cannot unmerge TransformedGrid::Single"),
82            Self::Merged { leader, members } => {
83                let len = members.len();
84                let members = members.drain((len - count)..).collect();
85                if len == count {
86                    let tmp = MutableSubgrid::empty();
87                    let leader = std::mem::replace(leader, tmp);
88                    *self = Self::Single(leader);
89                }
90                members
91            }
92        }
93    }
94}
95
/// Owned destination buffers for a Modular image, together with its header and MA config.
///
/// `image_channels` holds one grid per entry of `channels.info`, sized to that channel's
/// width and height; `meta_channels` holds grids requested by the header's transforms.
#[derive(Debug)]
pub struct ModularImageDestination<S: Sample> {
    header: ModularHeader,
    ma_ctx: MaConfig,
    /// Group dimension in samples; must be non-zero for `prepare_gmodular`/`prepare_groups`.
    /// NOTE(review): presumably a power of two, since `prepare_groups` uses
    /// `trailing_zeros` as its log2 — confirm against callers.
    group_dim: u32,
    bit_depth: u32,
    channels: ModularChannels,
    /// Grids created by each transform's `prepare_meta_channels`.
    meta_channels: Vec<AlignedGrid<S>>,
    /// One grid per image channel in `channels.info`.
    image_channels: Vec<AlignedGrid<S>>,
}
106
107impl<S: Sample> ModularImageDestination<S> {
108    pub(crate) fn new(
109        header: ModularHeader,
110        ma_ctx: MaConfig,
111        group_dim: u32,
112        bit_depth: u32,
113        channels: ModularChannels,
114        tracker: Option<&AllocTracker>,
115    ) -> Result<Self> {
116        let mut meta_channels = Vec::new();
117        for tr in &header.transform {
118            tr.prepare_meta_channels(&mut meta_channels, tracker)?;
119        }
120
121        let image_channels = channels
122            .info
123            .iter()
124            .map(|ch| {
125                AlignedGrid::with_alloc_tracker(ch.width as usize, ch.height as usize, tracker)
126            })
127            .collect::<std::result::Result<_, _>>()?;
128
129        Ok(Self {
130            header,
131            ma_ctx,
132            group_dim,
133            bit_depth,
134            channels,
135            meta_channels,
136            image_channels,
137        })
138    }
139
140    pub fn try_clone(&self) -> Result<Self> {
141        Ok(Self {
142            header: self.header.clone(),
143            ma_ctx: self.ma_ctx.clone(),
144            group_dim: self.group_dim,
145            bit_depth: self.bit_depth,
146            channels: self.channels.clone(),
147            meta_channels: self
148                .meta_channels
149                .iter()
150                .map(|x| x.try_clone())
151                .collect::<std::result::Result<_, _>>()?,
152            image_channels: self
153                .image_channels
154                .iter()
155                .map(|x| x.try_clone())
156                .collect::<std::result::Result<_, _>>()?,
157        })
158    }
159
160    pub fn image_channels(&self) -> &[AlignedGrid<S>] {
161        &self.image_channels
162    }
163
164    pub fn into_image_channels(self) -> Vec<AlignedGrid<S>> {
165        self.image_channels
166    }
167
168    pub fn into_image_channels_with_info(
169        self,
170    ) -> impl Iterator<Item = (AlignedGrid<S>, ModularChannelInfo)> {
171        self.image_channels.into_iter().zip(self.channels.info)
172    }
173
174    pub fn has_palette(&self) -> bool {
175        self.header.transform.iter().any(|tr| tr.is_palette())
176    }
177
178    pub fn has_squeeze(&self) -> bool {
179        self.header.transform.iter().any(|tr| tr.is_squeeze())
180    }
181}
182
183impl<S: Sample> ModularImageDestination<S> {
184    pub fn prepare_gmodular(&mut self) -> Result<TransformedModularSubimage<S>> {
185        assert_ne!(self.group_dim, 0);
186
187        let group_dim = self.group_dim;
188        let subimage = self.prepare_subimage()?;
189        let (channel_info, grids): (Vec<_>, Vec<_>) = subimage
190            .channel_info
191            .into_iter()
192            .zip(subimage.grid)
193            .enumerate()
194            .take_while(|&(i, (ref info, _))| {
195                i < subimage.nb_meta_channels
196                    || (info.width <= group_dim && info.height <= group_dim)
197            })
198            .map(|(_, x)| x)
199            .unzip();
200        let channel_indices = (0..channel_info.len()).collect();
201        Ok(TransformedModularSubimage {
202            channel_info,
203            channel_indices,
204            grid: grids,
205            ..subimage
206        })
207    }
208
209    pub fn prepare_groups(
210        &mut self,
211        pass_shifts: &std::collections::BTreeMap<u32, (i32, i32)>,
212    ) -> Result<TransformedGlobalModular<S>> {
213        assert_ne!(self.group_dim, 0);
214
215        let num_passes = *pass_shifts.last_key_value().unwrap().0 as usize + 1;
216
217        let group_dim = self.group_dim;
218        let group_dim_shift = group_dim.trailing_zeros();
219        let bit_depth = self.bit_depth;
220        let subimage = self.prepare_subimage()?;
221        let it = subimage
222            .channel_info
223            .into_iter()
224            .zip(subimage.grid)
225            .enumerate()
226            .skip_while(|&(i, (ref info, _))| {
227                i < subimage.nb_meta_channels
228                    || (info.width <= group_dim && info.height <= group_dim)
229            });
230
231        let mut lf_groups = Vec::new();
232        let mut pass_groups = Vec::with_capacity(num_passes);
233        pass_groups.resize_with(num_passes, Vec::new);
234        for (i, (info, grid)) in it {
235            let ModularChannelInfo {
236                original_width,
237                original_height,
238                hshift,
239                vshift,
240                original_shift,
241                ..
242            } = info;
243            assert!(hshift >= 0 && vshift >= 0);
244
245            let grid = match grid {
246                TransformedGrid::Single(g) => g,
247                TransformedGrid::Merged { leader, .. } => leader,
248            };
249            tracing::trace!(
250                i,
251                width = grid.width(),
252                height = grid.height(),
253                hshift,
254                vshift
255            );
256
257            let (groups, grids) = if hshift < 3 || vshift < 3 {
258                let shift = hshift.min(vshift); // shift < 3
259                let pass_idx = *pass_shifts
260                    .iter()
261                    .find(|&(_, &(minshift, maxshift))| (minshift..maxshift).contains(&shift))
262                    .unwrap()
263                    .0;
264                let pass_idx = pass_idx as usize;
265
266                let group_width = group_dim >> hshift;
267                let group_height = group_dim >> vshift;
268                if group_width == 0 || group_height == 0 {
269                    tracing::error!(
270                        group_dim,
271                        hshift,
272                        vshift,
273                        "Channel shift value too large after transform"
274                    );
275                    return Err(crate::Error::InvalidSqueezeParams);
276                }
277
278                let grids = grid.into_groups_with_fixed_count(
279                    group_width as usize,
280                    group_height as usize,
281                    (original_width + group_dim - 1) as usize >> group_dim_shift,
282                    (original_height + group_dim - 1) as usize >> group_dim_shift,
283                );
284                (&mut pass_groups[pass_idx], grids)
285            } else {
286                // hshift >= 3 && vshift >= 3
287                let lf_group_width = group_dim >> (hshift - 3);
288                let lf_group_height = group_dim >> (vshift - 3);
289                if lf_group_width == 0 || lf_group_height == 0 {
290                    tracing::error!(
291                        group_dim,
292                        hshift,
293                        vshift,
294                        "Channel shift value too large after transform"
295                    );
296                    return Err(crate::Error::InvalidSqueezeParams);
297                }
298                let grids = grid.into_groups_with_fixed_count(
299                    lf_group_width as usize,
300                    lf_group_height as usize,
301                    (original_width + (group_dim << 3) - 1) as usize >> (group_dim_shift + 3),
302                    (original_height + (group_dim << 3) - 1) as usize >> (group_dim_shift + 3),
303                );
304                (&mut lf_groups, grids)
305            };
306
307            if groups.is_empty() {
308                groups.resize_with(grids.len(), || {
309                    TransformedModularSubimage::empty(&subimage.header, &subimage.ma_ctx, bit_depth)
310                });
311            } else if groups.len() != grids.len() {
312                panic!();
313            }
314
315            for (subimage, grid) in groups.iter_mut().zip(grids) {
316                let width = grid.width() as u32;
317                let height = grid.height() as u32;
318                if width == 0 || height == 0 {
319                    continue;
320                }
321
322                subimage.channel_info.push(ModularChannelInfo {
323                    width,
324                    height,
325                    original_width: width << hshift,
326                    original_height: height << vshift,
327                    hshift,
328                    vshift,
329                    original_shift,
330                });
331                subimage.channel_indices.push(i);
332                subimage.grid.push(grid.into());
333                subimage.partial = true;
334            }
335        }
336
337        Ok(TransformedGlobalModular {
338            lf_groups,
339            pass_groups,
340        })
341    }
342
343    pub fn prepare_subimage(&mut self) -> Result<TransformedModularSubimage<S>> {
344        let mut channels = self.channels.clone();
345        let mut meta_channel_grids = self
346            .meta_channels
347            .iter_mut()
348            .map(MutableSubgrid::from)
349            .collect::<Vec<_>>();
350        let mut grids = self
351            .image_channels
352            .iter_mut()
353            .map(|g| g.as_subgrid_mut().into())
354            .collect::<Vec<_>>();
355        for tr in &self.header.transform {
356            tr.transform_channels(&mut channels, &mut meta_channel_grids, &mut grids)?;
357        }
358
359        let channel_info = channels.info;
360        let channel_indices = (0..channel_info.len()).collect();
361        Ok(TransformedModularSubimage {
362            header: self.header.clone(),
363            ma_ctx: self.ma_ctx.clone(),
364            bit_depth: self.bit_depth,
365            nb_meta_channels: channels.nb_meta_channels as usize,
366            channel_info,
367            channel_indices,
368            grid: grids,
369            partial: true,
370        })
371    }
372}
373
/// Per-group subimages produced by `ModularImageDestination::prepare_groups`.
#[derive(Debug)]
pub struct TransformedGlobalModular<'dest, S: Sample> {
    /// Subimages for channels with `hshift >= 3` and `vshift >= 3`, one per LF group.
    pub lf_groups: Vec<TransformedModularSubimage<'dest, S>>,
    /// Indexed by pass, then by group; holds channels with `hshift < 3` or `vshift < 3`.
    pub pass_groups: Vec<Vec<TransformedModularSubimage<'dest, S>>>,
}
379
/// A set of channels decoded together as one Modular (sub)image.
///
/// The grids borrow from the destination buffers for `'dest`.
#[derive(Debug)]
pub struct TransformedModularSubimage<'dest, S: Sample> {
    header: ModularHeader,
    ma_ctx: MaConfig,
    bit_depth: u32,
    /// Number of leading meta channels in `channel_info`/`grid`.
    nb_meta_channels: usize,
    /// Per-channel info, parallel to `grid`.
    channel_info: Vec<ModularChannelInfo>,
    /// For each channel, its index in the transformed channel list it was taken from.
    channel_indices: Vec<usize>,
    grid: Vec<TransformedGrid<'dest, S>>,
    /// `true` while channels still await a complete decode; `decode` clears it on success.
    partial: bool,
}
391
392impl<S: Sample> TransformedModularSubimage<'_, S> {
393    fn empty(header: &ModularHeader, ma_ctx: &MaConfig, bit_depth: u32) -> Self {
394        Self {
395            header: header.clone(),
396            ma_ctx: ma_ctx.clone(),
397            bit_depth,
398            nb_meta_channels: 0,
399            channel_info: Vec::new(),
400            channel_indices: Vec::new(),
401            grid: Vec::new(),
402            partial: false,
403        }
404    }
405}
406
407impl<'dest, S: Sample> TransformedModularSubimage<'dest, S> {
408    pub fn is_empty(&self) -> bool {
409        self.channel_info.is_empty()
410    }
411
412    pub fn is_partial(&self) -> bool {
413        self.partial
414    }
415
416    pub fn recursive(
417        self,
418        bitstream: &mut Bitstream,
419        global_ma_config: Option<&MaConfig>,
420        tracker: Option<&AllocTracker>,
421    ) -> Result<RecursiveModularImage<'dest, S>> {
422        let channels = crate::ModularChannels {
423            info: self.channel_info,
424            nb_meta_channels: 0,
425        };
426        let (header, ma_ctx) = crate::read_and_validate_local_modular_header(
427            bitstream,
428            &channels,
429            global_ma_config,
430            tracker,
431        )?;
432
433        let mut image = RecursiveModularImage {
434            header,
435            ma_ctx,
436            bit_depth: self.bit_depth,
437            channels,
438            meta_channels: Vec::new(),
439            image_channels: self.grid,
440        };
441        for tr in &image.header.transform {
442            tr.prepare_meta_channels(&mut image.meta_channels, tracker)?;
443        }
444        Ok(image)
445    }
446
447    pub fn finish(mut self, pool: &jxl_threadpool::JxlThreadPool) -> bool {
448        for tr in self.header.transform.iter().rev() {
449            tr.inverse(&mut self.grid, self.bit_depth, pool);
450        }
451        !self.partial
452    }
453}
454
455impl<S: Sample> TransformedModularSubimage<'_, S> {
456    fn decode_inner(&mut self, bitstream: &mut Bitstream, stream_index: u32) -> Result<()> {
457        let span = tracing::span!(tracing::Level::TRACE, "decode channels", stream_index);
458        let _guard = span.enter();
459
460        let dist_multiplier = self
461            .channel_info
462            .iter()
463            .map(|info| info.width)
464            .max()
465            .unwrap_or(0);
466
467        let mut decoder = self.ma_ctx.decoder().clone();
468        decoder.begin(bitstream)?;
469
470        let mut ma_tree_list = Vec::with_capacity(self.channel_info.len());
471        for (i, info) in self.channel_info.iter().enumerate() {
472            if info.width == 0 || info.height == 0 {
473                ma_tree_list.push(None);
474                continue;
475            }
476
477            let filtered_prev_len = self.channel_info[..i]
478                .iter()
479                .filter(|prev_info| {
480                    info.width == prev_info.width
481                        && info.height == prev_info.height
482                        && info.hshift == prev_info.hshift
483                        && info.vshift == prev_info.vshift
484                })
485                .count();
486
487            let ma_tree =
488                self.ma_ctx
489                    .make_flat_tree(i as u32, stream_index, filtered_prev_len as u32);
490            ma_tree_list.push(Some(ma_tree));
491        }
492
493        if let Some(mut rle_decoder) = decoder.as_rle() {
494            let is_fast_lossless = ma_tree_list.iter().all(|ma_tree| {
495                ma_tree
496                    .as_ref()
497                    .map(|ma_tree| {
498                        matches!(
499                            ma_tree.single_node(),
500                            Some(MaTreeLeafClustered {
501                                predictor: Predictor::Gradient,
502                                offset: 0,
503                                multiplier: 1,
504                                ..
505                            })
506                        )
507                    })
508                    .unwrap_or(true)
509            });
510
511            if is_fast_lossless {
512                tracing::trace!("libjxl fast-lossless");
513                let mut rle_state = RleState::<S>::new();
514
515                for (ma_tree, grid) in ma_tree_list.into_iter().zip(&mut self.grid) {
516                    let Some(ma_tree) = ma_tree else {
517                        continue;
518                    };
519
520                    let node = ma_tree.single_node().unwrap();
521                    let cluster = node.cluster;
522                    decode_fast_lossless(
523                        bitstream,
524                        &mut rle_decoder,
525                        &mut rle_state,
526                        cluster,
527                        grid.grid_mut(),
528                    );
529                }
530
531                rle_state.check_error()?;
532                // Prefix code doesn't have checksum
533                return Ok(());
534            }
535        }
536
537        let wp_header = &self.header.wp_params;
538        let mut predictor = PredictorState::new();
539        let mut prev_map = HashMap::new();
540        for ((info, ma_tree), grid) in self
541            .channel_info
542            .iter()
543            .zip(ma_tree_list)
544            .zip(&mut self.grid)
545        {
546            let Some(ma_tree) = ma_tree else {
547                continue;
548            };
549            let key = (info.width, info.height, info.hshift, info.vshift);
550
551            let filtered_prev = prev_map.entry(key).or_insert_with(Vec::new);
552
553            if let Some(node) = ma_tree.single_node() {
554                decode_single_node(
555                    bitstream,
556                    &mut decoder,
557                    dist_multiplier,
558                    &mut predictor,
559                    wp_header,
560                    grid.grid_mut(),
561                    node,
562                )?;
563            } else if let Some(table) = ma_tree.simple_table() {
564                decode_simple_table(
565                    bitstream,
566                    &mut decoder,
567                    dist_multiplier,
568                    &mut predictor,
569                    wp_header,
570                    grid.grid_mut(),
571                    &table,
572                )?;
573            } else {
574                let grid = grid.grid_mut();
575                let filtered_prev = &filtered_prev[..ma_tree.max_prev_channel_depth()];
576                let wp_header = ma_tree.need_self_correcting().then_some(wp_header);
577                predictor.reset(grid.width() as u32, filtered_prev, wp_header);
578                decode_slow(
579                    bitstream,
580                    &mut decoder,
581                    dist_multiplier,
582                    &ma_tree,
583                    &mut predictor,
584                    grid,
585                )?;
586            }
587
588            filtered_prev.insert(0, grid.grid());
589        }
590
591        decoder.finalize()?;
592        Ok(())
593    }
594
595    pub fn decode(
596        &mut self,
597        bitstream: &mut Bitstream,
598        stream_index: u32,
599        allow_partial: bool,
600    ) -> Result<()> {
601        match self.decode_inner(bitstream, stream_index) {
602            Err(e) if e.unexpected_eof() && allow_partial => {
603                tracing::debug!("Partially decoded Modular image");
604            }
605            Err(e) => return Err(e),
606            Ok(_) => {
607                self.partial = false;
608            }
609        }
610        Ok(())
611    }
612}
613
/// A Modular image nested inside a subimage, created by `TransformedModularSubimage::recursive`.
#[derive(Debug)]
pub struct RecursiveModularImage<'dest, S: Sample> {
    /// Local header read from the bitstream.
    header: ModularHeader,
    ma_ctx: MaConfig,
    bit_depth: u32,
    /// Channels of the parent subimage (`recursive` sets `nb_meta_channels` to 0).
    channels: ModularChannels,
    /// Grids created by this image's own transforms.
    meta_channels: Vec<AlignedGrid<S>>,
    /// The parent subimage's grids, still borrowing from the destination.
    image_channels: Vec<TransformedGrid<'dest, S>>,
}
623
624impl<S: Sample> RecursiveModularImage<'_, S> {
625    pub fn prepare_subimage(&mut self) -> Result<TransformedModularSubimage<S>> {
626        let mut channels = self.channels.clone();
627        let mut meta_channel_grids = self
628            .meta_channels
629            .iter_mut()
630            .map(|g| {
631                let width = g.width();
632                let height = g.height();
633                MutableSubgrid::from_buf(g.buf_mut(), width, height, width)
634            })
635            .collect::<Vec<_>>();
636        let mut grids = self
637            .image_channels
638            .iter_mut()
639            .map(|g| g.reborrow())
640            .collect();
641        for tr in &self.header.transform {
642            tr.transform_channels(&mut channels, &mut meta_channel_grids, &mut grids)?;
643        }
644
645        let channel_info = channels.info;
646        let channel_indices = (0..channel_info.len()).collect();
647        Ok(TransformedModularSubimage {
648            header: self.header.clone(),
649            ma_ctx: self.ma_ctx.clone(),
650            bit_depth: self.bit_depth,
651            nb_meta_channels: channels.nb_meta_channels as usize,
652            channel_info,
653            channel_indices,
654            grid: grids,
655            partial: true,
656        })
657    }
658}
659
/// Run-length decoding state used by the fast-lossless path.
///
/// One instance is shared by every channel decoded in that path.
struct RleState<S: Sample> {
    /// Most recently decoded sample; `Repeat` tokens reuse it.
    value: S,
    /// How many more times `value` is returned before a new token is read.
    repeat: u32,
    /// First decoding error encountered, reported by `check_error`.
    error: Option<Box<jxl_coding::Error>>,
}
665
impl<S: Sample> RleState<S> {
    /// Creates a state with no pending repeats and no recorded error.
    #[inline]
    fn new() -> Self {
        Self {
            value: S::default(),
            repeat: 0,
            error: None,
        }
    }

    /// Returns the next sample of the run-length coded stream.
    ///
    /// When no repeats remain, one token is read from `cluster`. `Value(v)` replaces the
    /// current sample and makes it available once. `Repeat(len)` returns the current sample
    /// `len` times in total, including this call. On a read error, only the first error is
    /// stored in `self.error` and the previous sample is returned; call `check_error`
    /// afterwards to surface it.
    #[inline(always)]
    fn decode(
        &mut self,
        bitstream: &mut Bitstream,
        decoder: &mut DecoderRleMode,
        cluster: u8,
    ) -> S {
        if self.repeat == 0 {
            let result = decoder.read_varint_clustered(bitstream, cluster);
            match result {
                Ok(RleToken::Value(v)) => {
                    self.value = S::unpack_signed_u32(v);
                    self.repeat = 1;
                }
                Ok(RleToken::Repeat(len)) => {
                    // NOTE(review): assumes `len > 0`; a zero length would wrap `repeat`
                    // below just like the error case — confirm jxl_coding never emits it.
                    self.repeat = len;
                }
                // Latch only the first error; later ones are dropped.
                Err(e) if self.error.is_none() => {
                    self.error = Some(Box::new(e));
                }
                _ => {}
            }
        }

        // Consume one use. After a read error `repeat` is still 0 here, so it wraps to
        // `u32::MAX` and no further tokens are read from the failed stream.
        self.repeat = self.repeat.wrapping_sub(1);
        self.value
    }

    /// Returns the first recorded decoding error (if any) and clears it.
    #[inline]
    fn check_error(&mut self) -> Result<()> {
        let error = self.error.take();
        if let Some(error) = error {
            let error = *error;
            Err(error.into())
        } else {
            Ok(())
        }
    }
}
715
716fn decode_single_node<S: Sample>(
717    bitstream: &mut Bitstream,
718    decoder: &mut Decoder,
719    dist_multiplier: u32,
720    predictor_state: &mut PredictorState<S>,
721    wp_header: &WpHeader,
722    grid: &mut MutableSubgrid<S>,
723    node: &MaTreeLeafClustered,
724) -> Result<()> {
725    let &MaTreeLeafClustered {
726        cluster,
727        predictor,
728        offset,
729        multiplier,
730    } = node;
731    tracing::trace!(cluster, ?predictor, "Single MA tree node");
732
733    let height = grid.height();
734    let single_token = decoder.single_token(cluster);
735    match (predictor, single_token) {
736        (Predictor::Zero, Some(token)) => {
737            tracing::trace!("Single token in cluster, Zero predictor: hyper fast path");
738            let value = S::unpack_signed_u32(token).wrapping_muladd_i32(multiplier as i32, offset);
739            for y in 0..height {
740                grid.get_row_mut(y).fill(value);
741            }
742            Ok(())
743        }
744        (Predictor::Zero, None) => {
745            tracing::trace!("Zero predictor: fast path");
746            for y in 0..height {
747                let row = grid.get_row_mut(y);
748                for out in row {
749                    let token = decoder.read_varint_with_multiplier_clustered(
750                        bitstream,
751                        cluster,
752                        dist_multiplier,
753                    )?;
754                    *out =
755                        S::unpack_signed_u32(token).wrapping_muladd_i32(multiplier as i32, offset);
756                }
757            }
758            Ok(())
759        }
760        (Predictor::Gradient, _) if offset == 0 && multiplier == 1 => {
761            tracing::trace!("Simple gradient: quite fast path");
762            decode_simple_grad(bitstream, decoder, cluster, dist_multiplier, grid)
763        }
764        _ => {
765            let wp_header = (predictor == Predictor::SelfCorrecting).then_some(wp_header);
766            predictor_state.reset(grid.width() as u32, &[], wp_header);
767            decode_single_node_slow(
768                bitstream,
769                decoder,
770                dist_multiplier,
771                node,
772                predictor_state,
773                grid,
774            )
775        }
776    }
777}
778
779#[inline(never)]
780fn decode_fast_lossless<S: Sample>(
781    bitstream: &mut Bitstream,
782    decoder: &mut DecoderRleMode,
783    rle_state: &mut RleState<S>,
784    cluster: u8,
785    grid: &mut MutableSubgrid<S>,
786) {
787    let height = grid.height();
788
789    {
790        let mut w = S::default();
791        let out_row = grid.get_row_mut(0);
792        for out in &mut *out_row {
793            let token = rle_state.decode(bitstream, decoder, cluster);
794            w = w.add(token);
795            *out = w;
796        }
797    }
798
799    for y in 1..height {
800        let (u, mut d) = grid.split_vertical(y);
801        let prev_row = u.get_row(y - 1);
802        let out_row = d.get_row_mut(0);
803
804        let token = rle_state.decode(bitstream, decoder, cluster);
805        let mut w = token.add(prev_row[0]);
806        out_row[0] = w;
807
808        for (window, out) in prev_row.windows(2).zip(&mut out_row[1..]) {
809            let nw = window[0];
810            let n = window[1];
811            let pred = S::grad_clamped(n, w, nw);
812
813            let token = rle_state.decode(bitstream, decoder, cluster);
814            w = token.add(pred);
815            *out = w;
816        }
817    }
818}
819
820#[inline(never)]
821fn decode_simple_grad<S: Sample>(
822    bitstream: &mut Bitstream,
823    decoder: &mut Decoder,
824    cluster: u8,
825    dist_multiplier: u32,
826    grid: &mut MutableSubgrid<S>,
827) -> Result<()> {
828    let width = grid.width();
829    let height = grid.height();
830
831    {
832        let mut w = S::default();
833        let out_row = grid.get_row_mut(0);
834        for out in out_row[..width].iter_mut() {
835            let token = decoder.read_varint_with_multiplier_clustered(
836                bitstream,
837                cluster,
838                dist_multiplier,
839            )?;
840            w = S::unpack_signed_u32(token).add(w);
841            *out = w;
842        }
843    }
844
845    for y in 1..height {
846        let (u, mut d) = grid.split_vertical(y);
847        let prev_row = u.get_row(y - 1);
848        let out_row = d.get_row_mut(0);
849
850        let token =
851            decoder.read_varint_with_multiplier_clustered(bitstream, cluster, dist_multiplier)?;
852        let mut w = S::unpack_signed_u32(token).add(prev_row[0]);
853        out_row[0] = w;
854
855        for (window, out) in prev_row.windows(2).zip(&mut out_row[1..]) {
856            let nw = window[0];
857            let n = window[1];
858            let pred = S::grad_clamped(n, w, nw);
859
860            let token = decoder.read_varint_with_multiplier_clustered(
861                bitstream,
862                cluster,
863                dist_multiplier,
864            )?;
865            let value = S::unpack_signed_u32(token).add(pred);
866            *out = value;
867            w = value;
868        }
869    }
870
871    Ok(())
872}
873
874#[inline(always)]
875fn decode_one<S: Sample, const EDGE: bool>(
876    bitstream: &mut Bitstream,
877    decoder: &mut Decoder,
878    dist_multiplier: u32,
879    leaf: &MaTreeLeafClustered,
880    properties: &Properties<S>,
881) -> Result<S> {
882    let diff = S::unpack_signed_u32(decoder.read_varint_with_multiplier_clustered(
883        bitstream,
884        leaf.cluster,
885        dist_multiplier,
886    )?);
887    let diff = diff.wrapping_muladd_i32(leaf.multiplier as i32, leaf.offset);
888    let predictor = leaf.predictor;
889    let sample_prediction = predictor.predict::<_, EDGE>(properties);
890    Ok(diff.add(S::from_i32(sample_prediction)))
891}
892
893#[inline(never)]
894fn decode_single_node_slow<S: Sample>(
895    bitstream: &mut Bitstream,
896    decoder: &mut Decoder,
897    dist_multiplier: u32,
898    leaf: &MaTreeLeafClustered,
899    predictor: &mut PredictorState<S>,
900    grid: &mut MutableSubgrid<S>,
901) -> Result<()> {
902    let height = grid.height();
903    for y in 0..2usize.min(height) {
904        let row = grid.get_row_mut(y);
905
906        for out in row.iter_mut() {
907            let properties = predictor.properties::<true>();
908            let true_value =
909                decode_one::<_, true>(bitstream, decoder, dist_multiplier, leaf, &properties)?;
910            *out = true_value;
911            properties.record(true_value.to_i32());
912        }
913    }
914
915    for y in 2..height {
916        let row = grid.get_row_mut(y);
917        let (row_left, row_middle, row_right) = if row.len() <= 4 {
918            (row, [].as_mut(), [].as_mut())
919        } else {
920            let (l, m) = row.split_at_mut(2);
921            let (m, r) = m.split_at_mut(m.len() - 2);
922            (l, m, r)
923        };
924
925        for out in row_left {
926            let properties = predictor.properties::<true>();
927            let true_value =
928                decode_one::<_, true>(bitstream, decoder, dist_multiplier, leaf, &properties)?;
929            *out = true_value;
930            properties.record(true_value.to_i32());
931        }
932        for out in row_middle {
933            let properties = predictor.properties::<false>();
934            let true_value =
935                decode_one::<_, false>(bitstream, decoder, dist_multiplier, leaf, &properties)?;
936            *out = true_value;
937            properties.record(true_value.to_i32());
938        }
939        for out in row_right {
940            let properties = predictor.properties::<true>();
941            let true_value =
942                decode_one::<_, true>(bitstream, decoder, dist_multiplier, leaf, &properties)?;
943            *out = true_value;
944            properties.record(true_value.to_i32());
945        }
946    }
947
948    Ok(())
949}
950
951fn decode_simple_table<S: Sample>(
952    bitstream: &mut Bitstream,
953    decoder: &mut Decoder,
954    dist_multiplier: u32,
955    predictor_state: &mut PredictorState<S>,
956    wp_header: &WpHeader,
957    grid: &mut MutableSubgrid<S>,
958    table: &SimpleMaTable,
959) -> Result<()> {
960    let &SimpleMaTable {
961        decision_prop,
962        value_base,
963        predictor,
964        offset,
965        multiplier,
966        ref cluster_table,
967    } = table;
968
969    if offset == 0 && multiplier == 1 && decision_prop == 9 && predictor == Predictor::Gradient {
970        return decode_gradient_table(
971            bitstream,
972            decoder,
973            dist_multiplier,
974            grid,
975            value_base,
976            cluster_table,
977        );
978    }
979
980    decode_simple_table_slow(
981        bitstream,
982        decoder,
983        dist_multiplier,
984        predictor_state,
985        wp_header,
986        grid,
987        table,
988    )
989}
990
/// Maps a decision property value to a cluster id through `cluster_table`.
///
/// The table index is `sample - value_base`, clamped to
/// `0..=cluster_table.len() - 1`. Values below the table's range select the
/// first entry, and values above it select the last entry.
///
/// The subtraction saturates instead of wrapping. `sample` may come from
/// wrapping arithmetic on decoded samples and can be arbitrarily large or
/// small. A plain `-` would then panic in debug builds, or in release builds
/// wrap around and clamp to the wrong end of the table.
///
/// # Panics
///
/// Panics if `cluster_table` is empty.
#[inline(always)]
fn cluster_from_table(sample: i32, value_base: i32, cluster_table: &[u8]) -> u8 {
    let last = cluster_table.len() as i32 - 1;
    // Saturation preserves ordering relative to the clamp bounds (0 and `last`
    // both lie within i32), so the clamped result equals the exact-arithmetic one.
    let index = sample.saturating_sub(value_base).clamp(0, last);
    cluster_table[index as usize]
}
996
997fn decode_gradient_table<S: Sample>(
998    bitstream: &mut Bitstream,
999    decoder: &mut Decoder,
1000    dist_multiplier: u32,
1001    grid: &mut MutableSubgrid<S>,
1002    value_base: i32,
1003    cluster_table: &[u8],
1004) -> Result<()> {
1005    tracing::trace!("Gradient-only lookup table");
1006
1007    let width = grid.width();
1008    let height = grid.height();
1009
1010    {
1011        let mut w = S::default();
1012        let out_row = grid.get_row_mut(0);
1013        for out in out_row[..width].iter_mut() {
1014            let cluster = cluster_from_table(w.to_i32(), value_base, cluster_table);
1015            let token = decoder.read_varint_with_multiplier_clustered(
1016                bitstream,
1017                cluster,
1018                dist_multiplier,
1019            )?;
1020            w = S::unpack_signed_u32(token).add(w);
1021            *out = w;
1022        }
1023    }
1024
1025    for y in 1..height {
1026        let (u, mut d) = grid.split_vertical(y);
1027        let prev_row = u.get_row(y - 1);
1028        let out_row = d.get_row_mut(0);
1029
1030        let cluster = cluster_from_table(prev_row[0].to_i32(), value_base, cluster_table);
1031        let token =
1032            decoder.read_varint_with_multiplier_clustered(bitstream, cluster, dist_multiplier)?;
1033        let mut w = S::unpack_signed_u32(token).add(prev_row[0]);
1034        out_row[0] = w;
1035
1036        for (window, out) in prev_row.windows(2).zip(&mut out_row[1..]) {
1037            let nw = window[0];
1038            let n = window[1];
1039            let prop = n
1040                .to_i32()
1041                .wrapping_add(w.to_i32())
1042                .wrapping_sub(nw.to_i32());
1043            let pred = S::grad_clamped(n, w, nw);
1044
1045            let cluster = cluster_from_table(prop, value_base, cluster_table);
1046            let token = decoder.read_varint_with_multiplier_clustered(
1047                bitstream,
1048                cluster,
1049                dist_multiplier,
1050            )?;
1051            let value = S::unpack_signed_u32(token).add(pred);
1052            *out = value;
1053            w = value;
1054        }
1055    }
1056
1057    Ok(())
1058}
1059
1060#[inline(always)]
1061fn decode_table_one<S: Sample, const EDGE: bool>(
1062    bitstream: &mut Bitstream,
1063    decoder: &mut Decoder,
1064    dist_multiplier: u32,
1065    table: &SimpleMaTable,
1066    properties: &Properties<S>,
1067) -> Result<S> {
1068    let prop_value = properties.get(table.decision_prop as usize);
1069
1070    let cluster = cluster_from_table(prop_value, table.value_base, &table.cluster_table);
1071
1072    let diff = S::unpack_signed_u32(decoder.read_varint_with_multiplier_clustered(
1073        bitstream,
1074        cluster,
1075        dist_multiplier,
1076    )?);
1077    let diff = diff.wrapping_muladd_i32(table.multiplier as i32, table.offset);
1078    let predictor = table.predictor;
1079    let sample_prediction = predictor.predict::<_, EDGE>(properties);
1080    Ok(diff.add(S::from_i32(sample_prediction)))
1081}
1082
1083#[inline(never)]
1084fn decode_simple_table_slow<S: Sample>(
1085    bitstream: &mut Bitstream,
1086    decoder: &mut Decoder,
1087    dist_multiplier: u32,
1088    predictor_state: &mut PredictorState<S>,
1089    wp_header: &WpHeader,
1090    grid: &mut MutableSubgrid<S>,
1091    table: &SimpleMaTable,
1092) -> Result<()> {
1093    tracing::trace!("Slow lookup table");
1094
1095    let need_wp_header = table.decision_prop == 15 || table.predictor == Predictor::SelfCorrecting;
1096    let wp_header = need_wp_header.then_some(wp_header);
1097    predictor_state.reset(grid.width() as u32, &[], wp_header);
1098
1099    let height = grid.height();
1100    for y in 0..2usize.min(height) {
1101        let row = grid.get_row_mut(y);
1102
1103        for out in row.iter_mut() {
1104            let properties = predictor_state.properties::<true>();
1105            let true_value = decode_table_one::<_, true>(
1106                bitstream,
1107                decoder,
1108                dist_multiplier,
1109                table,
1110                &properties,
1111            )?;
1112            *out = true_value;
1113            properties.record(true_value.to_i32());
1114        }
1115    }
1116
1117    for y in 2..height {
1118        let row = grid.get_row_mut(y);
1119        let (row_left, row_middle, row_right) = if row.len() <= 4 {
1120            (row, [].as_mut(), [].as_mut())
1121        } else {
1122            let (l, m) = row.split_at_mut(2);
1123            let (m, r) = m.split_at_mut(m.len() - 2);
1124            (l, m, r)
1125        };
1126
1127        for out in row_left {
1128            let properties = predictor_state.properties::<true>();
1129            let true_value = decode_table_one::<_, true>(
1130                bitstream,
1131                decoder,
1132                dist_multiplier,
1133                table,
1134                &properties,
1135            )?;
1136            *out = true_value;
1137            properties.record(true_value.to_i32());
1138        }
1139        for out in row_middle {
1140            let properties = predictor_state.properties::<false>();
1141            let true_value = decode_table_one::<_, false>(
1142                bitstream,
1143                decoder,
1144                dist_multiplier,
1145                table,
1146                &properties,
1147            )?;
1148            *out = true_value;
1149            properties.record(true_value.to_i32());
1150        }
1151        for out in row_right {
1152            let properties = predictor_state.properties::<true>();
1153            let true_value = decode_table_one::<_, true>(
1154                bitstream,
1155                decoder,
1156                dist_multiplier,
1157                table,
1158                &properties,
1159            )?;
1160            *out = true_value;
1161            properties.record(true_value.to_i32());
1162        }
1163    }
1164
1165    Ok(())
1166}
1167
1168#[inline(never)]
1169fn decode_slow<S: Sample>(
1170    bitstream: &mut Bitstream,
1171    decoder: &mut Decoder,
1172    dist_multiplier: u32,
1173    ma_tree: &FlatMaTree,
1174    predictor: &mut PredictorState<S>,
1175    grid: &mut MutableSubgrid<S>,
1176) -> Result<()> {
1177    let height = grid.height();
1178    for y in 0..2usize.min(height) {
1179        let row = grid.get_row_mut(y);
1180
1181        for out in row.iter_mut() {
1182            let properties = predictor.properties::<true>();
1183            let leaf = ma_tree.get_leaf(&properties);
1184            let true_value =
1185                decode_one::<_, true>(bitstream, decoder, dist_multiplier, leaf, &properties)?;
1186            *out = true_value;
1187            properties.record(true_value.to_i32());
1188        }
1189    }
1190
1191    for y in 2..height {
1192        let row = grid.get_row_mut(y);
1193        let (row_left, row_middle, row_right) = if row.len() <= 4 {
1194            (row, [].as_mut(), [].as_mut())
1195        } else {
1196            let (l, m) = row.split_at_mut(2);
1197            let (m, r) = m.split_at_mut(m.len() - 2);
1198            (l, m, r)
1199        };
1200
1201        for out in row_left {
1202            let properties = predictor.properties::<true>();
1203            let leaf = ma_tree.get_leaf(&properties);
1204            let true_value =
1205                decode_one::<_, true>(bitstream, decoder, dist_multiplier, leaf, &properties)?;
1206            *out = true_value;
1207            properties.record(true_value.to_i32());
1208        }
1209        for out in row_middle {
1210            let properties = predictor.properties::<false>();
1211            let leaf = ma_tree.get_leaf(&properties);
1212            let true_value =
1213                decode_one::<_, false>(bitstream, decoder, dist_multiplier, leaf, &properties)?;
1214            *out = true_value;
1215            properties.record(true_value.to_i32());
1216        }
1217        for out in row_right {
1218            let properties = predictor.properties::<true>();
1219            let leaf = ma_tree.get_leaf(&properties);
1220            let true_value =
1221                decode_one::<_, true>(bitstream, decoder, dist_multiplier, leaf, &properties)?;
1222            *out = true_value;
1223            properties.record(true_value.to_i32());
1224        }
1225    }
1226
1227    Ok(())
1228}