// jxl-modular 0.11.3
//
// JPEG XL Modular image decoder, part of jxl-oxide.
// Documentation
use jxl_grid::{MutableSubgrid, SharedSubgrid};

use crate::{
    Sample,
    predictor::{Predictor, PredictorState},
};

use super::Palette;

/// Fixed delta-palette table consulted for negative palette indices.
///
/// `Palette::inverse_inner` maps a negative index `i` to
/// `k = (-(i + 1)) % 143` and reads entry `(k + 1) >> 1`, so `k` in `0..143`
/// reaches every one of the 72 entries. The value is negated when `k` is even,
/// and left-shifted by `min(bit_depth, 24) - 8` when `bit_depth > 8`.
/// Only the first three channels read from this table; later channels get 0.
///
/// NOTE(review): presumably this is the default delta palette defined by the
/// JPEG XL specification (ISO/IEC 18181-1) — confirm the values against it.
#[rustfmt::skip]
const DELTA_PALETTE: [[i16; 3]; 72] = [
    [0, 0, 0], [4, 4, 4], [11, 0, 0], [0, 0, -13], [0, -12, 0], [-10, -10, -10],
    [-18, -18, -18], [-27, -27, -27], [-18, -18, 0], [0, 0, -32], [-32, 0, 0], [-37, -37, -37],
    [0, -32, -32], [24, 24, 45], [50, 50, 50], [-45, -24, -24], [-24, -45, -45], [0, -24, -24],
    [-34, -34, 0], [-24, 0, -24], [-45, -45, -24], [64, 64, 64], [-32, 0, -32], [0, -32, 0],
    [-32, 0, 32], [-24, -45, -24], [45, 24, 45], [24, -24, -45], [-45, -24, 24], [80, 80, 80],
    [64, 0, 0], [0, 0, -64], [0, -64, -64], [-24, -24, 45], [96, 96, 96], [64, 64, 0],
    [45, -24, -24], [34, -34, 0], [112, 112, 112], [24, -45, -45], [45, 45, -24], [0, -32, 32],
    [24, -24, 45], [0, 96, 96], [45, -24, 24], [24, -45, -24], [-24, -45, 24], [0, -64, 0],
    [96, 0, 0], [128, 128, 128], [64, 0, 64], [144, 144, 144], [96, 96, 0], [-36, -36, 36],
    [45, -24, -45], [45, -45, -24], [0, 0, -96], [0, 128, 128], [0, 96, 0], [45, 24, -45],
    [-128, 0, 0], [24, -45, 24], [-45, 24, -45], [64, 0, -64], [64, -64, -64], [96, 0, 96],
    [45, -45, 24], [24, 45, -45], [64, 64, -64], [128, 128, 0], [0, 0, -128], [-24, 45, -45],
];

impl Palette {
    /// Inverts the palette transform in place.
    ///
    /// On entry `targets[0]` holds palette indices. On return, each grid in
    /// `targets` holds the reconstructed samples of its channel. Row `c` of
    /// `palette` is the colour table for channel `c`, indexed by palette index.
    ///
    /// How each index is resolved:
    /// - `0..nb_colours`: explicit entry `palette[c][index]`;
    /// - `>= nb_colours`: implicit colour derived from the index and `bit_depth`;
    /// - negative: an entry of `DELTA_PALETTE` (channels `>= 3` become zero).
    ///
    /// Pixels whose index is `< nb_deltas` (negative indices included) are
    /// deltas. After the lookup, the output of `self.d_pred`, computed from the
    /// previously reconstructed samples of the same channel, is added to them.
    ///
    /// # Panics
    ///
    /// Panics if `targets` is empty or `targets.len() != palette.height()`.
    pub(crate) fn inverse_inner<S: Sample>(
        &self,
        palette: SharedSubgrid<S>,
        mut targets: Vec<MutableSubgrid<S>>,
        bit_depth: u32,
    ) {
        let nb_deltas = self.nb_deltas as i32;
        let nb_colors = self.nb_colours as i32;

        // The fast path is a pure table lookup, so it is only valid when no
        // pixel needs delta prediction. Every index must therefore be an
        // explicit palette entry AND at least `nb_deltas`, since indices below
        // `nb_deltas` are deltas that must go through the predictor below.
        let is_simple = {
            let index_grid = targets[0].as_shared();
            let height = index_grid.height();
            (0..height).all(|y| {
                let row = index_grid.get_row(y);
                row.iter()
                    .all(|&index| (nb_deltas..nb_colors).contains(&index.to_i32()))
            })
        };

        if is_simple {
            return inverse_simple(palette, targets);
        }

        tracing::trace!("Inverse palette, slow path");

        let mut need_delta = Vec::new();
        let width = targets[0].width();
        let height = targets[0].height();
        let channels = targets.len();
        assert_eq!(channels, palette.height());
        for y in 0..height {
            for x in 0..width {
                let index = targets[0].get(x, y).to_i32();
                // Indices below `nb_deltas` (negative ones included) get delta
                // prediction added in the second pass. Positions are recorded
                // in raster order.
                if index < nb_deltas {
                    need_delta.push((x, y));
                }

                // Lazily yields this pixel's sample in every channel; each
                // branch below consumes it exactly once.
                let channels_it = targets.iter_mut().map(|g| g.get_mut(x, y));

                if (0..nb_colors).contains(&index) {
                    // Explicit palette entry.
                    for (c, sample) in channels_it.enumerate() {
                        *sample = palette.get(index as usize, c);
                    }
                } else if index >= nb_colors {
                    // Implicit colour computed from the index.
                    let index = index - nb_colors;
                    if index < 64 {
                        for (c, sample) in channels_it.enumerate() {
                            *sample = S::from_i32(
                                ((index >> (2 * c)) % 4) * ((1i32 << bit_depth) - 1) / 4
                                    + (1i32 << bit_depth.saturating_sub(3)),
                            );
                        }
                    } else {
                        // Base-5 digits of `index - 64`, one digit per channel.
                        let mut index = index - 64;
                        for sample in channels_it {
                            *sample = S::from_i32((index % 5) * ((1i32 << bit_depth) - 1) / 4);
                            index /= 5;
                        }
                    }
                } else {
                    // Negative index: delta palette. The entry and its sign
                    // depend only on the index, so compute them once per pixel.
                    let delta_index = (-(index + 1) % 143) as usize;
                    let entry = &DELTA_PALETTE[(delta_index + 1) >> 1];
                    let negate = delta_index & 1 == 0;
                    for (c, sample) in channels_it.enumerate() {
                        if c >= 3 {
                            *sample = S::default();
                            continue;
                        }

                        let mut temp_sample = entry[c] as i32;
                        if negate {
                            temp_sample = -temp_sample;
                        }
                        if bit_depth > 8 {
                            temp_sample <<= bit_depth.min(24) - 8;
                        }
                        *sample = S::from_i32(temp_sample);
                    }
                }
            }
        }

        if need_delta.is_empty() {
            return;
        }

        let d_pred = self.d_pred;
        // Only the self-correcting predictor consumes the WP header.
        let wp_header = if d_pred == Predictor::SelfCorrecting {
            self.wp_header.as_ref()
        } else {
            None
        };
        let mut predictor = PredictorState::<S>::new();

        // Second pass: per channel, replay the predictor in raster order and
        // add its prediction to the recorded delta positions.
        'outer: for mut grid in targets {
            predictor.reset(width as u32, &[], wp_header);

            let mut idx = 0;
            for y in 0..height {
                for x in 0..width {
                    let properties = predictor.properties::<true>();
                    let sample = grid.get_mut(x, y);
                    let mut sample_value = sample.to_i32();
                    if need_delta[idx] == (x, y) {
                        let diff = d_pred.predict::<_, true>(&properties);
                        sample_value = sample_value.wrapping_add(diff);
                        *sample = S::from_i32(sample_value);
                        idx += 1;
                        // No deltas remain in this channel; the predictor state
                        // for the rest of it is never needed.
                        if idx >= need_delta.len() {
                            continue 'outer;
                        }
                    }
                    properties.record(sample_value);
                }
            }
        }
    }
}

/// Fast path of the inverse palette: every index is a plain palette lookup.
///
/// `targets[0]` holds the palette indices and doubles as the output grid for
/// channel 0. All other channels are filled first, while the index grid is
/// still intact. Channel 0 is then resolved by overwriting the indices in
/// place.
///
/// Indexing `palette` panics if an index is negative or out of range; the
/// caller is expected to have checked this beforehand.
#[inline(never)]
fn inverse_simple<S: Sample>(palette: SharedSubgrid<S>, targets: Vec<MutableSubgrid<S>>) {
    let rows = targets[0].height();
    let channel_count = targets.len();
    assert_eq!(channel_count, palette.height());

    tracing::trace!("Inverse palette, fast path");

    let mut grids = targets.into_iter();
    let mut indices = grids.next().unwrap();

    // Channels 1.. read from the untouched index grid.
    for (channel, mut out) in (1..).zip(grids) {
        let lut = palette.get_row(channel);
        for y in 0..rows {
            let src = indices.get_row(y);
            let dst = out.get_row_mut(y);
            dst.iter_mut()
                .zip(src)
                .for_each(|(dst_sample, idx)| *dst_sample = lut[idx.to_i32() as usize]);
        }
    }

    // Channel 0 last, since resolving it destroys the indices.
    let lut = palette.get_row(0);
    for y in 0..rows {
        for value in indices.get_row_mut(y) {
            *value = lut[value.to_i32() as usize];
        }
    }
}