zenjxl-decoder 0.3.8

High performance Rust implementation of a JPEG XL decoder
Documentation
// Copyright (c) the JPEG XL Project Authors. All rights reserved.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

use std::borrow::Cow;

use crate::bit_reader::BitReader;
use crate::entropy_coding::decode::Histograms;
use crate::entropy_coding::decode::SymbolReader;
use crate::error::{Error, Result};
use crate::util::{
    CeilLog2, MemoryTracker, NewWithCapacity, tracing_wrappers::instrument, value_of_lowest_1_bit,
};

/// A permutation of indices, stored as a slice of `u32` values.
///
/// The storage is a [`Cow`], so a permutation can either borrow a `'static`
/// slice or own a vector of its own.
#[derive(Clone, Debug, Default, PartialEq)]
pub struct Permutation(pub Cow<'static, [u32]>);

/// Lets a [`Permutation`] be used wherever a `&[u32]` is expected.
impl std::ops::Deref for Permutation {
    type Target = [u32];

    /// Borrows the underlying index slice.
    fn deref(&self) -> &Self::Target {
        self.0.as_ref()
    }
}

impl Permutation {
    /// Decode a permutation from entropy-coded stream.
    pub fn decode(
        size: u32,
        skip: u32,
        histograms: &Histograms,
        br: &mut BitReader,
        entropy_reader: &mut SymbolReader,
        memory_tracker: &MemoryTracker,
    ) -> Result<Self> {
        let end = entropy_reader.read_unsigned(histograms, br, get_context(size));
        Self::decode_inner(size, skip, end, memory_tracker, |ctx| -> Result<u32> {
            let r = entropy_reader.read_unsigned(histograms, br, ctx);
            br.check_for_error()?;
            Ok(r)
        })
    }

    fn decode_inner(
        size: u32,
        skip: u32,
        end: u32,
        memory_tracker: &MemoryTracker,
        mut read: impl FnMut(usize) -> Result<u32>,
    ) -> Result<Self> {
        if end > size - skip {
            return Err(Error::InvalidPermutationSize { size, skip, end });
        }

        // Check memory budget for permutation allocations (lehmer + permutation vectors)
        memory_tracker.check_alloc((size as u64) * 4)?;
        let mut lehmer = Vec::new_with_capacity(end as usize)?;

        let mut prev_val = 0u32;
        for idx in skip..(skip + end) {
            let val = match read(get_context(prev_val)) {
                Ok(val) => val,
                Err(Error::OutOfBounds(_)) => {
                    // Estimate 1.5 bits for each remaining code
                    let bits = (((skip + end) - idx) as usize).saturating_mul(3) / 2;
                    return Err(Error::OutOfBounds(bits));
                }
                Err(e) => return Err(e),
            };
            if val >= size - idx {
                return Err(Error::InvalidPermutationLehmerCode {
                    size,
                    idx,
                    lehmer: val,
                });
            }
            lehmer.push(val);
            prev_val = val;
        }

        // Initialize the full permutation vector with skipped elements intact
        let mut permutation = Vec::new_with_capacity((size - skip) as usize)?;
        permutation.extend(0..size);

        // Decode the Lehmer code into the slice starting at `skip`
        let permuted_slice = decode_lehmer_code(&lehmer, &permutation[skip as usize..])?;

        // Replace the target slice in `permutation`
        permutation[skip as usize..].copy_from_slice(&permuted_slice);

        // Ensure the permutation has the correct size
        assert_eq!(permutation.len(), size as usize);

        Ok(Self(Cow::Owned(permutation)))
    }

    pub fn compose(&mut self, other: &Permutation) {
        assert_eq!(self.0.len(), other.0.len());
        let mut tmp: Vec<u32> = vec![0; self.0.len()];
        for (i, val) in tmp.iter_mut().enumerate().take(self.0.len()) {
            *val = self.0[other.0[i] as usize]
        }
        self.0.to_mut().copy_from_slice(&tmp[..]);
    }
}

// Decodes the Lehmer code in `code` and returns the permuted slice.
//
// The result has the same length as `permutation_slice`. Output position `i`
// receives the `code[i]`-th (0-based) element of `permutation_slice` that has
// not been emitted yet; positions past the end of `code` behave as if their
// code were 0.
//
// Returns `Error::InvalidPermutationLehmerCode` when `permutation_slice` is
// empty or when a code exceeds the number of elements still available.
// Allocation failures from `new_with_capacity` are propagated.
#[instrument(level = "debug", ret, err)]
fn decode_lehmer_code(code: &[u32], permutation_slice: &[u32]) -> Result<Vec<u32>> {
    let n = permutation_slice.len();
    if n == 0 {
        return Err(Error::InvalidPermutationLehmerCode {
            size: 0,
            idx: 0,
            lehmer: 0,
        });
    }

    // Output buffer, pre-filled with the input; every entry is overwritten in
    // the loop below.
    let mut permuted = Vec::new_with_capacity(n)?;
    permuted.extend_from_slice(permutation_slice);

    // The implicit tree below is sized to the next power of two >= n.
    // NOTE(review): `n as u32` truncates if the slice is longer than u32::MAX.
    let padded_n = (n as u32).next_power_of_two() as usize;

    // Tree storage: slot `k` starts at `value_of_lowest_1_bit(k + 1)`.
    // Presumably that helper returns the lowest set bit of its argument, which
    // would make each slot the count of unused elements in the index range it
    // covers -- TODO confirm against `crate::util`.
    let mut temp = Vec::new_with_capacity(padded_n)?;
    temp.extend((0..padded_n as u32).map(|x| value_of_lowest_1_bit(x + 1)));

    for (i, permuted_item) in permuted.iter_mut().enumerate() {
        // Missing codes (i >= code.len()) are treated as 0.
        let code_i = *code.get(i).unwrap_or(&0);

        // Only `n - i` elements are still unused at step `i`, so the largest
        // valid code is `n - i - 1`.
        if code_i as usize > n - i - 1 {
            return Err(Error::InvalidPermutationLehmerCode {
                size: n as u32,
                idx: i as u32,
                lehmer: code_i,
            });
        }

        // 1-based rank of the wanted element among the unused ones.
        let mut rank = code_i + 1;

        // Extract the `rank`-th unused element via implicit order-statistics tree:
        // walk from the widest step (`padded_n`) down to 1, moving `next`
        // forward whenever the candidate slot holds fewer than `rank` entries.
        let mut bit = padded_n;
        let mut next = 0usize;
        while bit != 0 {
            let cand = next + bit;
            // Defensive bounds check so the `temp[cand - 1]` access below
            // cannot go out of range.
            if cand == 0 || cand > padded_n {
                return Err(Error::InvalidPermutationLehmerCode {
                    size: n as u32,
                    idx: i as u32,
                    lehmer: code_i,
                });
            }
            bit >>= 1;
            if temp[cand - 1] < rank {
                next = cand;
                rank -= temp[cand - 1];
            }
        }

        // `next` is now the 0-based index of the selected input element.
        *permuted_item = permutation_slice[next];

        // Mark that element as used: starting from its 1-based position,
        // decrement each visited slot and advance by `value_of_lowest_1_bit`
        // of the current position until the end of the tree is passed.
        next += 1;
        while next <= padded_n {
            temp[next - 1] -= 1;
            next += value_of_lowest_1_bit(next as u32) as usize;
        }
    }

    Ok(permuted)
}

/// Straightforward reference decoder used to cross-check `decode_lehmer_code`
/// in tests.
///
/// Each code picks (and removes) the element at that index among the entries
/// of `permutation_slice` that have not been consumed yet; entries that are
/// never picked are appended afterwards in their original relative order.
///
/// Fails with `Error::InvalidPermutationLehmerCode` when `code` is empty, when
/// `code` is longer than `permutation_slice`, or when a code points past the
/// entries that remain.
#[cfg(test)]
fn decode_lehmer_code_naive(code: &[u32], permutation_slice: &[u32]) -> Result<Vec<u32>> {
    let code_len = code.len();
    // Every failure reports the code length as `size`.
    let invalid = |idx: u32, lehmer: u32| Error::InvalidPermutationLehmerCode {
        size: code_len as u32,
        idx,
        lehmer,
    };

    if code.is_empty() {
        return Err(invalid(0, 0));
    }

    // There must be at least one source element per code.
    if code_len > permutation_slice.len() {
        return Err(invalid(0, 0));
    }

    // Pool of elements that have not been emitted yet.
    let mut remaining = permutation_slice.to_vec();
    let mut permuted = Vec::new_with_capacity(code_len)?;

    for (position, &lehmer) in code.iter().enumerate() {
        let pick = lehmer as usize;
        if pick >= remaining.len() {
            return Err(invalid(position as u32, lehmer));
        }

        // Move the selected element from the pool to the output.
        permuted.push(remaining.remove(pick));
    }

    // Whatever was never selected keeps its relative order at the tail.
    permuted.append(&mut remaining);

    Ok(permuted)
}

/// Maps a value to the entropy-coding context used for the next permutation
/// symbol: `ceil_log2(x + 1)`, capped at 7.
///
/// The increment saturates, so `x == u32::MAX` no longer overflows (which
/// would panic in debug builds and wrap to 0 in release builds). Such a value
/// is expected to land in the top context because of the cap.
fn get_context(x: u32) -> usize {
    x.saturating_add(1).ceil_log2().min(7) as usize
}

// Unit and property tests for the Lehmer-code decoders.
#[cfg(test)]
mod test {
    use super::*;
    use arbtest::arbitrary::{self, Arbitrary, Unstructured};
    use core::assert_eq;
    use test_log::test;

    // Property test: for arbitrary inputs the optimized decoder and the naive
    // reference decoder must agree, both on success values and on errors
    // (errors are compared through their string form).
    #[test]
    fn generate_permutation_arbtest() {
        arbtest::arbtest(|u| {
            let input = PermutationInput::arbitrary(u)?;

            let permutation_slice = input.permutation.as_slice();

            let perm1 = decode_lehmer_code(&input.code, permutation_slice);
            let perm2 = decode_lehmer_code_naive(&input.code, permutation_slice);

            assert_eq!(
                perm1.map_err(|x| x.to_string()),
                perm2.map_err(|x| x.to_string())
            );
            Ok(())
        });
    }

    // Randomly generated input for the property test: a Lehmer code plus the
    // slice it is applied to.
    #[derive(Debug)]
    struct PermutationInput {
        // Lehmer code; entry `i` is at most `code.len() - i - 1`.
        code: Vec<u32>,
        // Slice to permute; at least as long as `code`.
        permutation: Vec<u32>,
    }

    impl<'a> Arbitrary<'a> for PermutationInput {
        // Values are drawn from `u` in a fixed order: the code length, each
        // code entry, the slice length, the swap count, then the swap
        // positions.
        fn arbitrary(u: &mut Unstructured<'a>) -> Result<Self, arbitrary::Error> {
            // Generate a reasonable size to prevent tests from taking too long
            let size_lehmer = u.int_in_range(1..=1000)?;

            let mut lehmer: Vec<u32> = Vec::with_capacity(size_lehmer as usize);
            for i in 0..size_lehmer {
                // Largest code that is valid at position `i`.
                let max_val = size_lehmer - i - 1;
                let val = if max_val > 0 {
                    u.int_in_range(0..=max_val)?
                } else {
                    0
                };
                lehmer.push(val);
            }

            // Start from the identity over `0..size_permutation`, where the
            // slice is never shorter than the code.
            let mut permutation = Vec::new();
            let size_permutation = u.int_in_range(size_lehmer..=1000)?;
            permutation.extend(0..size_permutation);

            // Shuffle the identity with up to 100 random transpositions.
            let num_of_swaps = u.int_in_range(0..=100)?;
            for _ in 0..num_of_swaps {
                // Randomly swap two positions
                let pos1 = u.int_in_range(0..=size_permutation - 1)?;
                let pos2 = u.int_in_range(0..=size_permutation - 1)?;
                permutation.swap(pos1 as usize, pos2 as usize);
            }

            Ok(PermutationInput {
                code: lehmer,
                permutation,
            })
        }
    }

    // Decodes a short code against `4..16` and checks the full permutation
    // (with the skipped prefix `0..4` prepended) against a known answer.
    #[test]
    fn simple() {
        // Lehmer code: [1, 1, 2, 3, 3, 6, 0, 1]
        let code = vec![1u32, 1, 2, 3, 3, 6, 0, 1];
        let skip = 4;
        let size = 16;

        let permutation_slice: Vec<u32> = (skip..size).collect();

        let permuted = decode_lehmer_code(&code, &permutation_slice).unwrap();
        let permuted_naive = decode_lehmer_code_naive(&code, &permutation_slice).unwrap();

        let mut permutation = Vec::with_capacity(size as usize);
        permutation.extend(0..skip); // Add skipped elements
        permutation.extend(permuted.iter());
        let expected_permutation = vec![0, 1, 2, 3, 5, 6, 8, 10, 11, 15, 4, 9, 7, 12, 13, 14];

        assert_eq!(permutation, expected_permutation);
        assert_eq!(permuted, permuted_naive);
    }

    // Code shorter than the slice (8 codes, 12 elements): both decoders must
    // produce the same known result.
    #[test]
    fn decode_lehmer_compare_different_length() -> Result<(), Box<dyn std::error::Error>> {
        // Lehmer code: [1, 1, 2, 3, 3, 6, 0, 1]
        let code = vec![1u32, 1, 2, 3, 3, 6, 0, 1];
        let skip = 4;
        let size = 16;

        let permutation_slice: Vec<u32> = (skip..size).collect();

        let permuted_optimized = decode_lehmer_code(&code, &permutation_slice)?;
        let permuted_naive = decode_lehmer_code_naive(&code, &permutation_slice)?;

        let expected_permuted = vec![5u32, 6, 8, 10, 11, 15, 4, 9, 7, 12, 13, 14];

        assert_eq!(permuted_optimized, expected_permuted);
        assert_eq!(permuted_naive, expected_permuted);
        assert_eq!(permuted_optimized, permuted_naive);

        Ok(())
    }

    // Code and slice of equal length (5): both decoders must produce the same
    // known result.
    #[test]
    fn decode_lehmer_compare_same_length() -> Result<(), Box<dyn std::error::Error>> {
        // Lehmer code: [2, 3, 0, 0, 0]
        let code = vec![2u32, 3, 0, 0, 0];
        let n = code.len();
        let permutation_slice: Vec<u32> = (0..n as u32).collect();

        let permuted_optimized = decode_lehmer_code(&code, &permutation_slice)?;
        let permuted_naive = decode_lehmer_code_naive(&code, &permutation_slice)?;

        let expected_permutation = vec![2u32, 4, 0, 1, 3];

        assert_eq!(permuted_optimized, expected_permutation);
        assert_eq!(permuted_naive, expected_permutation);
        assert_eq!(permuted_optimized, permuted_naive);

        Ok(())
    }

    // A first code of 4 is one past the largest valid value (3) for a
    // four-element slice, so decoding must fail.
    #[test]
    fn lehmer_out_of_bounds() {
        let code = vec![4];
        let permutation_slice: Vec<u32> = (4..8).collect();

        let result = decode_lehmer_code(&code, &permutation_slice);
        assert!(result.is_err());
    }
}