jxl 0.1.4

High performance Rust implementation of a JPEG XL decoder
Documentation
// Copyright (c) the JPEG XL Project Authors. All rights reserved.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

use std::{collections::VecDeque, ops::Range};

use crate::{
    bit_reader::BitReader,
    entropy_coding::decode::{Histograms, SymbolReader},
    error::Result,
    frame::modular::{
        ModularChannel, Predictor, Tree,
        decode::{
            channel::ModularChannelDecoder,
            common::{make_pixel, precompute_references},
        },
        predict::{PredictionData, WeightedPredictorState},
        tree::{NUM_NONREF_PROPERTIES, PROPERTIES_PER_PREVCHAN, TreeNode, predict},
    },
    headers::modular::GroupHeader,
    image::Image,
};

// TODO(veluca): we probably want to make flatter trees eventually.

pub struct NoWpTree {
    nodes: Vec<TreeNode>,
    references: Image<i32>,
    property_buffer: Vec<i32>,
}

impl NoWpTree {
    fn new(
        nodes: Vec<TreeNode>,
        max_property_count: usize,
        channel: usize,
        stream: usize,
        xsize: usize,
    ) -> Result<Self> {
        let num_ref_props = max_property_count
            .saturating_sub(NUM_NONREF_PROPERTIES)
            .next_multiple_of(PROPERTIES_PER_PREVCHAN);
        let references = Image::<i32>::new((num_ref_props, xsize))?;
        let num_properties = NUM_NONREF_PROPERTIES + num_ref_props;
        let mut property_buffer: Vec<i32> = vec![0; num_properties];

        property_buffer[0] = channel as i32;
        property_buffer[1] = stream as i32;

        Ok(Self {
            nodes,
            references,
            property_buffer,
        })
    }
}

impl ModularChannelDecoder for NoWpTree {
    const NEEDS_TOP: bool = true;
    const NEEDS_TOPTOP: bool = true;

    fn init_row(&mut self, buffers: &mut [&mut ModularChannel], chan: usize, y: usize) {
        precompute_references(buffers, chan, y, &mut self.references);
        self.property_buffer[2..].fill(0);
    }

    fn decode_one(
        &mut self,
        prediction_data: PredictionData,
        pos: (usize, usize),
        xsize: usize,
        reader: &mut SymbolReader,
        br: &mut BitReader,
        histograms: &Histograms,
    ) -> i32 {
        let prediction_result = predict(
            &self.nodes,
            prediction_data,
            xsize,
            None,
            pos.0,
            pos.1,
            &self.references,
            &mut self.property_buffer,
        );
        let dec = reader.read_signed(histograms, br, prediction_result.context as usize);
        make_pixel(dec, prediction_result.multiplier, prediction_result.guess)
    }
}

pub struct GeneralTree {
    no_wp_tree: NoWpTree,
    wp_state: WeightedPredictorState,
}

impl GeneralTree {
    fn new(
        nodes: Vec<TreeNode>,
        max_property_count: usize,
        header: &GroupHeader,
        channel: usize,
        stream: usize,
        xsize: usize,
    ) -> Result<Self> {
        let wp_state = WeightedPredictorState::new(&header.wp_header, xsize);
        Ok(Self {
            no_wp_tree: NoWpTree::new(nodes, max_property_count, channel, stream, xsize)?,
            wp_state,
        })
    }
}

impl ModularChannelDecoder for GeneralTree {
    const NEEDS_TOP: bool = true;
    const NEEDS_TOPTOP: bool = true;

    fn init_row(&mut self, buffers: &mut [&mut ModularChannel], chan: usize, y: usize) {
        self.no_wp_tree.init_row(buffers, chan, y);
    }

    fn decode_one(
        &mut self,
        prediction_data: PredictionData,
        pos: (usize, usize),
        xsize: usize,
        reader: &mut SymbolReader,
        br: &mut BitReader,
        histograms: &Histograms,
    ) -> i32 {
        let prediction_result = predict(
            &self.no_wp_tree.nodes,
            prediction_data,
            xsize,
            Some(&mut self.wp_state),
            pos.0,
            pos.1,
            &self.no_wp_tree.references,
            &mut self.no_wp_tree.property_buffer,
        );
        let dec = reader.read_signed(histograms, br, prediction_result.context as usize);
        let val = make_pixel(dec, prediction_result.multiplier, prediction_result.guess);
        self.wp_state.update_errors(val, pos, xsize);
        val
    }
}

const LUT_MAX_SPLITVAL: i32 = 1023;
const LUT_MIN_SPLITVAL: i32 = -1024;
const LUT_TABLE_SIZE: usize = (LUT_MAX_SPLITVAL - LUT_MIN_SPLITVAL + 1) as usize;
const _: () = assert!(LUT_TABLE_SIZE.is_power_of_two());

pub struct WpOnlyLookup {
    lut: [u8; LUT_TABLE_SIZE], // Lookup (wp value -> *clustered* context id)
    wp_state: WeightedPredictorState,
}

fn make_lut(tree: &[TreeNode], histograms: &Histograms) -> Option<[u8; LUT_TABLE_SIZE]> {
    struct RangeAndNode {
        range: Range<i32>,
        node: u32,
    }
    let mut stack = vec![RangeAndNode {
        range: LUT_MIN_SPLITVAL..LUT_MAX_SPLITVAL + 1,
        node: 0,
    }];

    let mut ans = [0u8; LUT_TABLE_SIZE];
    while let Some(RangeAndNode { range, node }) = stack.pop() {
        let v = tree[node as usize];
        match v {
            TreeNode::Split {
                val, left, right, ..
            } => {
                let first_left = val + 1;
                if first_left >= range.end || first_left <= range.start {
                    return None;
                }
                stack.push(RangeAndNode {
                    range: first_left..range.end,
                    node: left,
                });
                stack.push(RangeAndNode {
                    range: range.start..first_left,
                    node: right,
                });
            }
            TreeNode::Leaf {
                offset,
                multiplier,
                id,
                ..
            } => {
                if offset != 0 || multiplier != 1 {
                    return None;
                }
                let start = range.start - LUT_MIN_SPLITVAL;
                let end = range.end - LUT_MIN_SPLITVAL;
                ans[start as usize..end as usize]
                    .fill(histograms.map_context_to_cluster(id as usize) as u8);
            }
        }
    }

    Some(ans)
}

impl WpOnlyLookup {
    fn new(
        tree: &[TreeNode],
        histograms: &Histograms,
        header: &GroupHeader,
        xsize: usize,
    ) -> Option<Self> {
        let wp_state = WeightedPredictorState::new(&header.wp_header, xsize);
        let lut = make_lut(tree, histograms)?;
        Some(Self { lut, wp_state })
    }
}

impl ModularChannelDecoder for WpOnlyLookup {
    const NEEDS_TOP: bool = true;
    const NEEDS_TOPTOP: bool = true;

    fn init_row(&mut self, _buffers: &mut [&mut ModularChannel], _chan: usize, _y: usize) {
        // nothing to do
    }

    #[inline(always)]
    fn decode_one(
        &mut self,
        prediction_data: PredictionData,
        pos: (usize, usize),
        xsize: usize,
        reader: &mut SymbolReader,
        br: &mut BitReader,
        histograms: &Histograms,
    ) -> i32 {
        let (wp_pred, property) = self
            .wp_state
            .predict_and_property(pos, xsize, &prediction_data);
        let ctx =
            self.lut[(property - LUT_MIN_SPLITVAL).clamp(0, LUT_TABLE_SIZE as i32 - 1) as usize];
        let dec = reader.read_signed_clustered(histograms, br, ctx as usize);
        let val = dec + wp_pred as i32;
        self.wp_state.update_errors(val, pos, xsize);
        val
    }
}

pub struct SingleGradientOnly {
    ctx: usize,
}

impl ModularChannelDecoder for SingleGradientOnly {
    const NEEDS_TOP: bool = true;
    const NEEDS_TOPTOP: bool = false;

    fn init_row(&mut self, _: &mut [&mut ModularChannel], _: usize, _: usize) {}

    #[inline(always)]
    fn decode_one(
        &mut self,
        prediction_data: PredictionData,
        _: (usize, usize),
        _: usize,
        reader: &mut SymbolReader,
        br: &mut BitReader,
        histograms: &Histograms,
    ) -> i32 {
        let pred = Predictor::Gradient.predict_one(prediction_data, 0);
        let dec = reader.read_signed(histograms, br, self.ctx);
        make_pixel(dec, 1, pred)
    }
}

#[allow(clippy::large_enum_variant)]
pub enum TreeSpecialCase {
    NoWp(NoWpTree),
    WpOnly(WpOnlyLookup),
    SingleGradientOnly(SingleGradientOnly),
    General(GeneralTree),
}

pub fn specialize_tree(
    tree: &Tree,
    channel: usize,
    stream: usize,
    xsize: usize,
    header: &GroupHeader,
) -> Result<TreeSpecialCase> {
    // TODO(veluca): consider skipping the pruning if header.uses_global_tree is true.
    let mut pruned_tree = Vec::new();
    let mut queue = VecDeque::new();
    pruned_tree.try_reserve(tree.nodes.len())?;
    queue.try_reserve(tree.nodes.len())?;
    queue.push_front(0);

    let mut uses_wp = false;
    let mut uses_non_wp = false;

    // Obtain a pruned tree without nodes that are not relevant in the current channel and stream.
    // Proceed in BFS order, so that we know that the children of  anode will be adjacent.
    while let Some(v) = queue.pop_front() {
        let node = tree.nodes[v as usize];
        match node {
            TreeNode::Split {
                property,
                val,
                left,
                right,
            } if property < 2 => {
                // If the node splits on static properties, re-enqueue its correct child immediately.
                let vv = if property == 0 { channel } else { stream };
                queue.push_front(if vv as i32 > val { left } else { right });
                continue;
            }
            TreeNode::Split {
                property,
                val,
                left,
                right,
            } => {
                // WeightedPredictor property.
                uses_wp |= property == 15;
                uses_non_wp |= property != 15;
                let base = (queue.len() + pruned_tree.len() + 1) as u32;
                pruned_tree.push(TreeNode::Split {
                    property,
                    val,
                    left: base,
                    right: base + 1,
                });
                queue.push_back(left);
                queue.push_back(right);
            }
            TreeNode::Leaf { predictor, .. } => {
                uses_wp |= predictor == Predictor::Weighted;
                uses_non_wp |= predictor != Predictor::Weighted;
                pruned_tree.push(node);
            }
        }
    }

    if let [
        TreeNode::Leaf {
            predictor: Predictor::Gradient,
            multiplier: 1,
            offset: 0,
            id,
        },
    ] = &*pruned_tree
    {
        return Ok(TreeSpecialCase::SingleGradientOnly(SingleGradientOnly {
            ctx: *id as usize,
        }));
    }

    if !uses_non_wp
        && let Some(wp) = WpOnlyLookup::new(&pruned_tree, &tree.histograms, header, xsize)
    {
        return Ok(TreeSpecialCase::WpOnly(wp));
    }

    if !uses_wp {
        return Ok(TreeSpecialCase::NoWp(NoWpTree::new(
            pruned_tree,
            tree.max_property_count(),
            channel,
            stream,
            xsize,
        )?));
    }

    Ok(TreeSpecialCase::General(GeneralTree::new(
        pruned_tree,
        tree.max_property_count(),
        header,
        channel,
        stream,
        xsize,
    )?))
}