zenjxl-decoder 0.3.8

High performance Rust implementation of a JPEG XL decoder
Documentation
// Copyright (c) the JPEG XL Project Authors. All rights reserved.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

//! Feature corpus tests: decode all generated JXL feature variants and compare
//! against djxl-decoded PNG references.
//!
//! The feature corpus is generated by `scripts/generate_feature_corpus.py` and
//! lives at `$FEATURE_CORPUS_PATH` (default: `/mnt/v/output/zenjxl-decoder/feature-corpus`).
//!
//! Each `.jxl` file is expected to have a matching `.png` sibling produced by djxl + oxipng;
//! `.jxl` files without one are skipped during discovery.

#[cfg(feature = "cms")]
use crate::api::MoxCms;
use crate::api::{
    JxlColorType, JxlDataFormat, JxlDecoder, JxlDecoderOptions, JxlOutputBuffer, JxlPixelFormat,
    ProcessingResult, states,
};
use crate::image::{Image, Rect};

use super::parity::{CONFORMANCE_THRESHOLD_U8, ReferenceImage, compare_u8_buffers};

/// Default location for the feature corpus.
const DEFAULT_CORPUS_PATH: &str = "/mnt/v/output/zenjxl-decoder/feature-corpus";

/// Resolve the root directory of the feature corpus.
///
/// Uses `$FEATURE_CORPUS_PATH` when it is set, otherwise [`DEFAULT_CORPUS_PATH`].
/// The variable is read as an `OsString`, so non-UTF-8 paths are honored instead
/// of silently falling back to the default.
///
/// Returns `None` unless the resolved path is an existing directory (a path that
/// exists but is a regular file cannot hold a corpus).
fn feature_corpus_dir() -> Option<std::path::PathBuf> {
    let path = std::env::var_os("FEATURE_CORPUS_PATH")
        .map(std::path::PathBuf::from)
        .unwrap_or_else(|| std::path::PathBuf::from(DEFAULT_CORPUS_PATH));
    path.is_dir().then_some(path)
}

/// Collect every `.jxl` file under the feature corpus root that has a `.png`
/// reference sitting next to it.
///
/// Returns `(name, jxl_path, png_path)` triples sorted by `name`. `name` is the
/// `.jxl` path relative to the corpus root (lossily converted to UTF-8).
/// Returns an empty list when the corpus directory is unavailable. Directories
/// that cannot be read are skipped without error.
fn discover_feature_corpus_tests() -> Vec<(String, std::path::PathBuf, std::path::PathBuf)> {
    let mut found = Vec::new();
    let Some(root) = feature_corpus_dir() else {
        return found;
    };

    // Walk the tree with an explicit work list instead of recursion; the final
    // sort makes the visiting order irrelevant.
    let mut pending = vec![root.clone()];
    while let Some(dir) = pending.pop() {
        let Ok(listing) = std::fs::read_dir(&dir) else {
            continue;
        };
        for candidate in listing.flatten().map(|entry| entry.path()) {
            if candidate.is_dir() {
                pending.push(candidate);
                continue;
            }
            let is_jxl = candidate.extension().and_then(|ext| ext.to_str()) == Some("jxl");
            if !is_jxl {
                continue;
            }
            let reference = candidate.with_extension("png");
            if !reference.exists() {
                continue;
            }
            let label = candidate
                .strip_prefix(&root)
                .unwrap_or(&candidate)
                .to_string_lossy()
                .into_owned();
            found.push((label, candidate, reference));
        }
    }

    found.sort_by(|lhs, rhs| lhs.0.cmp(&rhs.0));
    found
}

/// Decode a JXL file to u8 pixels.
///
/// Reads the whole file into memory and drives the decoder through three
/// stages: header, frame info, and pixel decode. Returns
/// `(width, height, channels, pixels)`, where `pixels` holds `height` rows of
/// `width * channels` interleaved 8-bit samples, copied row by row out of the
/// decode buffer. Only one frame is decoded: the function returns as soon as
/// the first pixel-decode stage completes.
///
/// The output layout is chosen from the image:
/// - grayscale vs. color follows the decoder's default color type;
/// - alpha is requested (GrayscaleAlpha / Rgba) whenever any extra channel is
///   of type `Alpha`;
/// - every extra channel is given a `None` format. NOTE(review): presumably this
///   means "do not write separately"; confirm against the `JxlPixelFormat` docs.
///
/// # Errors
///
/// Returns a descriptive `String` in these cases:
/// - the file cannot be read;
/// - the decoder still needs input after the data is exhausted (truncated file);
/// - the output image cannot be allocated;
/// - any decoder stage reports an error.
fn decode_jxl(path: &std::path::Path) -> Result<(usize, usize, usize, Vec<u8>), String> {
    let data = std::fs::read(path).map_err(|e| format!("read: {e}"))?;
    // Presumably advanced by `process` as bytes are consumed; the
    // `input.is_empty()` EOF checks below rely on that.
    let mut input = data.as_slice();

    // With the `cms` feature enabled, attach MoxCms as the color management system.
    #[cfg(feature = "cms")]
    let options = JxlDecoderOptions {
        cms: Some(Box::new(MoxCms::new())),
        ..JxlDecoderOptions::default()
    };
    #[cfg(not(feature = "cms"))]
    let options = JxlDecoderOptions::default();

    let mut decoder = JxlDecoder::<states::Initialized>::new(options);

    // Header
    // Each stage consumes the decoder. It either yields the next-state decoder
    // (`Complete`) or hands the same decoder back (`fallback`) to be retried.
    // The whole file is already in `input`, so needing more data once it is
    // empty means the file is truncated.
    let mut decoder = loop {
        match decoder.process(&mut input) {
            Ok(ProcessingResult::Complete { result }) => break result,
            Ok(ProcessingResult::NeedsMoreInput { fallback, .. }) => {
                if input.is_empty() {
                    return Err("EOF during header".into());
                }
                decoder = fallback;
            }
            Err(e) => return Err(format!("header: {e:?}")),
        }
    };

    let info = decoder.basic_info().clone();
    let (width, height) = info.size;

    // Gray vs. color is taken from the decoder's default pixel format.
    let default_format = decoder.current_pixel_format();
    let is_gray = matches!(
        default_format.color_type,
        JxlColorType::Grayscale | JxlColorType::GrayscaleAlpha
    );
    // Alpha presence is taken from the extra-channel list instead.
    let has_alpha = info.extra_channels.iter().any(|ec| {
        matches!(
            ec.ec_type,
            crate::headers::extra_channels::ExtraChannel::Alpha
        )
    });

    // Interleaved channel count per pixel for the chosen color type.
    let (color_type, channels) = match (is_gray, has_alpha) {
        (true, true) => (JxlColorType::GrayscaleAlpha, 2),
        (true, false) => (JxlColorType::Grayscale, 1),
        (false, true) => (JxlColorType::Rgba, 4),
        (false, false) => (JxlColorType::Rgb, 3),
    };

    let num_ec = info.extra_channels.len();
    let pixel_format = JxlPixelFormat {
        color_type,
        color_data_format: Some(JxlDataFormat::U8 { bit_depth: 8 }),
        // One entry per extra channel, all `None` (see the function docs).
        extra_channel_format: vec![None; num_ec],
    };
    decoder.set_pixel_format(pixel_format);

    // Frame info
    let mut decoder = loop {
        match decoder.process(&mut input) {
            Ok(ProcessingResult::Complete { result }) => break result,
            Ok(ProcessingResult::NeedsMoreInput { fallback, .. }) => {
                if input.is_empty() {
                    return Err("EOF before frame".into());
                }
                decoder = fallback;
            }
            Err(e) => return Err(format!("frame info: {e:?}")),
        }
    };

    // Decode
    // A single interleaved output buffer: `width * channels` bytes per row.
    let mut output =
        Image::<u8>::new((width * channels, height)).map_err(|e| format!("buffer: {e:?}"))?;
    let mut buffers = vec![JxlOutputBuffer::from_image_rect_mut(
        output
            .get_rect_mut(Rect {
                origin: (0, 0),
                size: (width * channels, height),
            })
            .into_raw(),
    )];

    loop {
        match decoder.process(&mut input, &mut buffers) {
            // The next-state decoder is dropped: only one frame is decoded.
            Ok(ProcessingResult::Complete { .. }) => break,
            Ok(ProcessingResult::NeedsMoreInput { fallback, .. }) => {
                if input.is_empty() {
                    return Err("EOF during frame".into());
                }
                decoder = fallback;
            }
            Err(e) => return Err(format!("decode: {e:?}")),
        }
    }

    // Flatten the image rows into one contiguous buffer.
    let mut pixels = Vec::with_capacity(width * height * channels);
    for y in 0..height {
        pixels.extend_from_slice(output.row(y));
    }
    Ok((width, height, channels, pixels))
}

/// Compare decoded pixels against reference PNG.
///
/// Decodes `jxl_path` and compares it against the reference at `png_path`.
/// The two sides may use different interleaved layouts: gray (1), gray+alpha
/// (2), RGB (3) or RGBA (4). For example, the decoder may emit RGBA while the
/// reference PNG was stored as RGB or grayscale. Differing layouts are first
/// projected onto a common one:
///
/// * color: RGB when both sides are color; gray when both are gray. A gray
///   decode is compared against the reference's first (R) channel. A gray
///   reference is compared against each of the decoded R, G and B channels.
/// * alpha: compared only when both sides carry an alpha channel.
///
/// # Errors
///
/// Returns a human-readable description in these cases:
/// - decoding fails or the reference cannot be loaded;
/// - the dimensions differ;
/// - the channel layouts cannot be reconciled;
/// - the pixel difference exceeds `CONFORMANCE_THRESHOLD_U8`.
fn compare_with_reference(
    jxl_path: &std::path::Path,
    png_path: &std::path::Path,
) -> Result<(), String> {
    let (width, height, channels, actual) = decode_jxl(jxl_path)?;

    let reference = ReferenceImage::load(png_path).map_err(|e| format!("ref load: {e}"))?;

    if width != reference.width || height != reference.height {
        return Err(format!(
            "size mismatch: {}x{} vs {}x{}",
            width, height, reference.width, reference.height
        ));
    }

    // Identical layouts need no projection.
    if channels == reference.channels {
        return compare_buffers_owned(&reference.pixels, &actual, width, height, channels);
    }

    let Some((act_picks, ref_picks)) = channel_projection(channels, reference.channels) else {
        return Err(format!(
            "channel mismatch: {} vs {}",
            channels, reference.channels
        ));
    };
    let act_px = select_channels(&actual, channels, &act_picks);
    let ref_px = select_channels(&reference.pixels, reference.channels, &ref_picks);
    compare_buffers_owned(&ref_px, &act_px, width, height, act_picks.len())
}

/// Pick channel indices that project two interleaved layouts onto a common one.
///
/// Both layouts must be gray, gray+alpha, RGB or RGBA (1..=4 channels). Returns
/// the indices to sample from each pixel of the actual and reference buffers,
/// in matching order. Returns `None` for unsupported channel counts.
fn channel_projection(actual: usize, reference: usize) -> Option<(Vec<usize>, Vec<usize>)> {
    let supported = 1..=4;
    if !supported.contains(&actual) || !supported.contains(&reference) {
        return None;
    }
    let is_color = |c: usize| c >= 3;
    let has_alpha = |c: usize| c == 2 || c == 4;

    let (mut act, mut refc) = match (is_color(actual), is_color(reference)) {
        (true, true) => (vec![0, 1, 2], vec![0, 1, 2]),
        // Gray decode: compare against the reference's first channel, which is
        // gray for a gray reference and R for a color reference.
        (false, _) => (vec![0], vec![0]),
        // Color decode vs. gray reference (e.g. the PNG was reduced to gray):
        // every decoded color channel must match the reference gray value.
        (true, false) => (vec![0, 1, 2], vec![0, 0, 0]),
    };
    if has_alpha(actual) && has_alpha(reference) {
        // Alpha is the last interleaved channel in both 2- and 4-channel layouts.
        act.push(actual - 1);
        refc.push(reference - 1);
    }
    Some((act, refc))
}

/// Build a new interleaved buffer from `pixels`.
///
/// `pixels` is read as `stride`-sized pixels. For each pixel, the channels
/// listed in `picks` are gathered in that order, so the output has
/// `picks.len()` channels per pixel.
fn select_channels(pixels: &[u8], stride: usize, picks: &[usize]) -> Vec<u8> {
    pixels
        .chunks_exact(stride)
        .flat_map(|px| picks.iter().map(move |&i| px[i]))
        .collect()
}

/// Compare two interleaved u8 buffers of `width x height` pixels with
/// `channels` samples each, using `CONFORMANCE_THRESHOLD_U8`.
///
/// Returns `Ok(())` when the comparison passes. Otherwise returns an error
/// summarising the maximum absolute error and the error count.
fn compare_buffers_owned(
    reference: &[u8],
    actual: &[u8],
    width: usize,
    height: usize,
    channels: usize,
) -> Result<(), String> {
    let stats = compare_u8_buffers(
        reference,
        actual,
        width,
        height,
        channels,
        CONFORMANCE_THRESHOLD_U8,
    );
    if !stats.passed {
        return Err(format!(
            "max_err={}, errs={}/{}",
            stats.max_abs_error, stats.error_count, stats.total_pixels
        ));
    }
    Ok(())
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: report how many feature corpus cases were discovered.
    ///
    /// Never fails. When nothing is found, it prints the location that was
    /// searched and how to generate the corpus.
    #[test]
    fn test_feature_corpus_discovery() {
        let tests = discover_feature_corpus_tests();
        if tests.is_empty() {
            // Report the location actually searched: FEATURE_CORPUS_PATH
            // overrides the default.
            let searched = std::env::var_os("FEATURE_CORPUS_PATH")
                .map(std::path::PathBuf::from)
                .unwrap_or_else(|| std::path::PathBuf::from(DEFAULT_CORPUS_PATH));
            eprintln!(
                "No feature corpus test cases found at {}",
                searched.display()
            );
            eprintln!("Set FEATURE_CORPUS_PATH or run scripts/generate_feature_corpus.py");
            return;
        }
        eprintln!("Found {} feature corpus test cases", tests.len());
    }

    /// Decode all feature corpus JXL files and compare against djxl reference PNGs.
    ///
    /// Panics inside a single case are caught and counted as crashes, so one bad
    /// file does not hide the results of the rest. The test fails at the end if
    /// any case failed or crashed.
    ///
    /// Run: cargo test --features cms feature_corpus::tests::test_all_feature_corpus -- --ignored --nocapture 2>&1 | tee /tmp/feature-corpus-test.log
    #[test]
    #[ignore]
    fn test_all_feature_corpus() {
        let tests = discover_feature_corpus_tests();
        if tests.is_empty() {
            panic!(
                "No feature corpus tests found. Run scripts/generate_feature_corpus.py or set FEATURE_CORPUS_PATH."
            );
        }

        eprintln!("Testing {} feature corpus files...\n", tests.len());

        let mut passed = 0usize;
        let mut failed = 0usize;
        let mut crashed = 0usize;
        let mut failures: Vec<(String, String)> = Vec::new();

        for (name, jxl_path, png_path) in &tests {
            let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
                compare_with_reference(jxl_path, png_path)
            }));

            match result {
                Ok(Ok(())) => {
                    passed += 1;
                }
                Ok(Err(e)) => {
                    // Separate the file name from the error text, matching the
                    // summary format below.
                    eprintln!("FAIL: {name}: {e}");
                    failures.push((name.clone(), e));
                    failed += 1;
                }
                Err(panic_info) => {
                    // Panic payloads are usually `String` (formatted) or `&str` (literal).
                    let msg = if let Some(s) = panic_info.downcast_ref::<String>() {
                        s.clone()
                    } else if let Some(s) = panic_info.downcast_ref::<&str>() {
                        s.to_string()
                    } else {
                        "unknown panic".to_string()
                    };
                    eprintln!("CRASH: {name}: {msg}");
                    failures.push((name.clone(), format!("PANIC: {msg}")));
                    crashed += 1;
                }
            }

            // Periodic progress line for long runs.
            let total = passed + failed + crashed;
            if total.is_multiple_of(100) {
                eprintln!(
                    "[{total}/{}] {passed} pass, {failed} fail, {crashed} crash",
                    tests.len()
                );
            }
        }

        eprintln!();
        eprintln!("=== Feature Corpus Results ===");
        eprintln!("Passed:  {passed}");
        eprintln!("Failed:  {failed}");
        eprintln!("Crashed: {crashed}");
        eprintln!("Total:   {}", tests.len());

        if !failures.is_empty() {
            eprintln!();
            eprintln!("=== Failures ===");
            for (name, err) in &failures {
                eprintln!("  {name}: {err}");
            }
        }

        if failed + crashed > 0 {
            panic!(
                "{} feature corpus tests failed ({failed} error, {crashed} crash)",
                failed + crashed
            );
        }
    }

    /// Regression: lossless grayscale+alpha modular image previously panicked in
    /// get_buffer (remaining_uses underflow). Fixed by checking buffer ownership
    /// in flush_output to skip channels with aliased buffer indices.
    #[test]
    fn test_gray_alpha_lossless() {
        let path = std::path::Path::new("resources/test/gray_alpha_lossless.jxl");
        decode_jxl(path).expect("gray_alpha_lossless.jxl should decode successfully");
    }
}