#[cfg(feature = "cms")]
use crate::api::MoxCms;
use crate::api::{
JxlColorType, JxlDataFormat, JxlDecoder, JxlDecoderOptions, JxlOutputBuffer, JxlPixelFormat,
ProcessingResult, states,
};
use crate::image::{Image, Rect};
use super::parity::{CONFORMANCE_THRESHOLD_U8, ReferenceImage, compare_u8_buffers};
/// Fallback location of the generated feature corpus, used when the
/// `FEATURE_CORPUS_PATH` environment variable is unset or empty.
const DEFAULT_CORPUS_PATH: &str = "/mnt/v/output/zenjxl-decoder/feature-corpus";

/// Resolves the root directory of the feature corpus.
///
/// Uses `FEATURE_CORPUS_PATH` when it is set to a non-empty value, otherwise
/// `DEFAULT_CORPUS_PATH`. Returns `None` unless the resolved path is an
/// existing directory, so callers can skip corpus tests gracefully.
fn feature_corpus_dir() -> Option<std::path::PathBuf> {
    // `var_os` (rather than `var`) keeps non-UTF-8 paths instead of silently
    // falling back to the default location.
    let path = std::env::var_os("FEATURE_CORPUS_PATH")
        .filter(|value| !value.is_empty())
        .map(std::path::PathBuf::from)
        .unwrap_or_else(|| std::path::PathBuf::from(DEFAULT_CORPUS_PATH));
    // A regular file at this path cannot be scanned, so treat it as missing.
    path.is_dir().then_some(path)
}
/// Collects every `.jxl` file under the corpus root that has a sibling `.png`
/// with the same stem.
///
/// Each entry is `(name, jxl_path, png_path)`. `name` is the `.jxl` path
/// relative to the corpus root (or the full path if it cannot be made
/// relative), and the list is sorted by `name`. Returns an empty list when
/// `feature_corpus_dir` yields `None`. Directories that cannot be read, and
/// entries that fail to enumerate, are skipped.
fn discover_feature_corpus_tests() -> Vec<(String, std::path::PathBuf, std::path::PathBuf)> {
    let mut found = Vec::new();
    let root = match feature_corpus_dir() {
        Some(dir) => dir,
        None => return found,
    };
    // Explicit work-list instead of recursion. Traversal order does not
    // matter because the result is sorted by name at the end.
    let mut pending = vec![root.clone()];
    while let Some(dir) = pending.pop() {
        let entries = match std::fs::read_dir(&dir) {
            Ok(entries) => entries,
            Err(_) => continue,
        };
        for candidate in entries.flatten().map(|entry| entry.path()) {
            if candidate.is_dir() {
                pending.push(candidate);
                continue;
            }
            if candidate.extension().and_then(|ext| ext.to_str()) != Some("jxl") {
                continue;
            }
            // Only files that have a reference rendering next to them are tests.
            let png = candidate.with_extension("png");
            if !png.exists() {
                continue;
            }
            let name = candidate
                .strip_prefix(&root)
                .unwrap_or(&candidate)
                .to_string_lossy()
                .into_owned();
            found.push((name, candidate, png));
        }
    }
    found.sort_by(|lhs, rhs| lhs.0.cmp(&rhs.0));
    found
}
/// Decodes the JPEG XL file at `path` into interleaved 8-bit pixels.
///
/// The output layout follows the decoder's default colour type (grayscale vs.
/// colour), plus whether the image declares an alpha extra channel:
/// 1 = gray, 2 = gray+alpha, 3 = RGB, 4 = RGBA. Only one frame is decoded: the
/// function returns after the first `process` call with output buffers
/// reports `Complete`.
///
/// Returns `(width, height, channels, pixels)`. `pixels` holds `height` rows,
/// copied one by one from an output image that is `width * channels` bytes
/// wide.
///
/// # Errors
/// Returns a prefixed message on:
/// - file read failure (`read:`);
/// - input exhausted while the decoder still wants more (`EOF …`);
/// - decoder errors (`header:`, `frame info:`, `decode:`);
/// - output-buffer creation failure (`buffer:`).
fn decode_jxl(path: &std::path::Path) -> Result<(usize, usize, usize, Vec<u8>), String> {
    let data = std::fs::read(path).map_err(|e| format!("read: {e}"))?;
    // `process` takes `&mut input`; the emptiness checks below rely on it
    // advancing this slice past the bytes it has consumed.
    let mut input = data.as_slice();
    // With the `cms` feature, a `MoxCms` instance is supplied in the decoder
    // options (presumably for colour-management conversions — confirm in
    // `JxlDecoderOptions` docs); otherwise the default options are used.
    #[cfg(feature = "cms")]
    let options = JxlDecoderOptions {
        cms: Some(Box::new(MoxCms::new())),
        ..JxlDecoderOptions::default()
    };
    #[cfg(not(feature = "cms"))]
    let options = JxlDecoderOptions::default();
    let mut decoder = JxlDecoder::<states::Initialized>::new(options);
    // Stage 1: process until the image header is available. On
    // `NeedsMoreInput` the decoder is handed back as `fallback`. The whole
    // file is already in memory, so an empty input at that point means the
    // file is truncated.
    let mut decoder = loop {
        match decoder.process(&mut input) {
            Ok(ProcessingResult::Complete { result }) => break result,
            Ok(ProcessingResult::NeedsMoreInput { fallback, .. }) => {
                if input.is_empty() {
                    return Err("EOF during header".into());
                }
                decoder = fallback;
            }
            Err(e) => return Err(format!("header: {e:?}")),
        }
    };
    let info = decoder.basic_info().clone();
    let (width, height) = info.size;
    // Grayscale vs. colour is taken from the decoder's default output format.
    let default_format = decoder.current_pixel_format();
    let is_gray = matches!(
        default_format.color_type,
        JxlColorType::Grayscale | JxlColorType::GrayscaleAlpha
    );
    // Alpha is requested whenever any extra channel is of type `Alpha`.
    let has_alpha = info.extra_channels.iter().any(|ec| {
        matches!(
            ec.ec_type,
            crate::headers::extra_channels::ExtraChannel::Alpha
        )
    });
    let (color_type, channels) = match (is_gray, has_alpha) {
        (true, true) => (JxlColorType::GrayscaleAlpha, 2),
        (true, false) => (JxlColorType::Grayscale, 1),
        (false, true) => (JxlColorType::Rgba, 4),
        (false, false) => (JxlColorType::Rgb, 3),
    };
    let num_ec = info.extra_channels.len();
    // Request 8-bit interleaved colour output. Every extra channel gets a
    // `None` format — presumably meaning no separate extra-channel buffer is
    // produced, with alpha carried by the colour type above. TODO confirm
    // against the `JxlPixelFormat` docs.
    let pixel_format = JxlPixelFormat {
        color_type,
        color_data_format: Some(JxlDataFormat::U8 { bit_depth: 8 }),
        extra_channel_format: vec![None; num_ec],
    };
    decoder.set_pixel_format(pixel_format);
    // Stage 2: process until the first frame's header/info is available.
    let mut decoder = loop {
        match decoder.process(&mut input) {
            Ok(ProcessingResult::Complete { result }) => break result,
            Ok(ProcessingResult::NeedsMoreInput { fallback, .. }) => {
                if input.is_empty() {
                    return Err("EOF before frame".into());
                }
                decoder = fallback;
            }
            Err(e) => return Err(format!("frame info: {e:?}")),
        }
    };
    // A single output buffer covering the whole image, `width * channels`
    // bytes per row.
    let mut output =
        Image::<u8>::new((width * channels, height)).map_err(|e| format!("buffer: {e:?}"))?;
    let mut buffers = vec![JxlOutputBuffer::from_image_rect_mut(
        output
            .get_rect_mut(Rect {
                origin: (0, 0),
                size: (width * channels, height),
            })
            .into_raw(),
    )];
    // Stage 3: decode the frame's pixels into `buffers`.
    loop {
        match decoder.process(&mut input, &mut buffers) {
            Ok(ProcessingResult::Complete { .. }) => break,
            Ok(ProcessingResult::NeedsMoreInput { fallback, .. }) => {
                if input.is_empty() {
                    return Err("EOF during frame".into());
                }
                decoder = fallback;
            }
            Err(e) => return Err(format!("decode: {e:?}")),
        }
    }
    // Copy the image row by row into one contiguous Vec.
    let mut pixels = Vec::with_capacity(width * height * channels);
    for y in 0..height {
        pixels.extend_from_slice(output.row(y));
    }
    Ok((width, height, channels, pixels))
}
/// Decodes `jxl_path` and compares the result against the PNG reference at
/// `png_path`. All pixel comparisons go through `compare_buffers_owned`,
/// which applies `CONFORMANCE_THRESHOLD_U8`.
///
/// When the decoded and reference channel counts match, every channel
/// (including alpha) is compared. Otherwise both sides are first reduced to a
/// common layout:
/// - decoded RGBA vs. reference RGB, or the reverse: alpha is dropped and RGB
///   is compared;
/// - decoded gray or gray+alpha vs. any reference with a different channel
///   count: only channel 0 of each side is compared. Alpha is ignored, and an
///   RGB reference is assumed to be neutral gray (R == G == B).
///
/// # Errors
/// Returns a description of the first problem found: decode failure,
/// reference load failure, size mismatch, unsupported channel combination,
/// or pixel differences reported by `compare_buffers_owned`.
fn compare_with_reference(
    jxl_path: &std::path::Path,
    png_path: &std::path::Path,
) -> Result<(), String> {
    let (width, height, channels, actual) = decode_jxl(jxl_path)?;
    let reference = ReferenceImage::load(png_path).map_err(|e| format!("ref load: {e}"))?;
    if width != reference.width || height != reference.height {
        return Err(format!(
            "size mismatch: {}x{} vs {}x{}",
            width, height, reference.width, reference.height
        ));
    }
    let ref_channels = reference.channels;
    match (channels, ref_channels) {
        _ if channels == ref_channels => {
            compare_buffers_owned(&reference.pixels, &actual, width, height, channels)
        }
        (4, 3) => compare_buffers_owned(
            &reference.pixels,
            &leading_channels(&actual, 4, 3),
            width,
            height,
            3,
        ),
        (3, 4) => compare_buffers_owned(
            &leading_channels(&reference.pixels, 4, 3),
            &actual,
            width,
            height,
            3,
        ),
        // Both ranges are non-zero, so `leading_channels` never gets a zero stride.
        (1 | 2, 1..=4) => compare_buffers_owned(
            &leading_channels(&reference.pixels, ref_channels, 1),
            &leading_channels(&actual, channels, 1),
            width,
            height,
            1,
        ),
        _ => Err(format!(
            "channel mismatch: {} vs {}",
            channels, ref_channels
        )),
    }
}

/// Copies the first `keep` bytes of every `stride`-byte pixel in `pixels`.
///
/// A trailing partial pixel (if `pixels.len()` is not a multiple of
/// `stride`) is ignored. `stride` must be non-zero and `keep <= stride`.
fn leading_channels(pixels: &[u8], stride: usize, keep: usize) -> Vec<u8> {
    pixels
        .chunks_exact(stride)
        .flat_map(|px| &px[..keep])
        .copied()
        .collect()
}
/// Compares `actual` against `reference` with `compare_u8_buffers` at the
/// `CONFORMANCE_THRESHOLD_U8` tolerance and turns the verdict into a `Result`.
///
/// `width`, `height` and `channels` are forwarded unchanged to
/// `compare_u8_buffers`. Despite the name, both inputs are borrowed slices.
///
/// # Errors
/// When the comparison does not pass, returns
/// `"max_err=<max abs error>, errs=<error count>/<total pixels>"`.
fn compare_buffers_owned(
    reference: &[u8],
    actual: &[u8],
    width: usize,
    height: usize,
    channels: usize,
) -> Result<(), String> {
    let verdict = compare_u8_buffers(
        reference,
        actual,
        width,
        height,
        channels,
        CONFORMANCE_THRESHOLD_U8,
    );
    if !verdict.passed {
        let summary = format!(
            "max_err={}, errs={}/{}",
            verdict.max_abs_error, verdict.error_count, verdict.total_pixels
        );
        return Err(summary);
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Extracts the message carried by a panic payload from `catch_unwind`.
    ///
    /// `panic!` payloads are a `String` (formatted messages) or a `&str`
    /// (literal messages); anything else is reported as `"unknown panic"`.
    fn panic_message(payload: &(dyn std::any::Any + Send)) -> String {
        if let Some(s) = payload.downcast_ref::<String>() {
            s.clone()
        } else if let Some(s) = payload.downcast_ref::<&str>() {
            s.to_string()
        } else {
            "unknown panic".to_string()
        }
    }

    /// Smoke test: reports how many corpus cases were discovered. Never fails,
    /// so it is safe to run without the corpus present.
    #[test]
    fn test_feature_corpus_discovery() {
        let tests = discover_feature_corpus_tests();
        if tests.is_empty() {
            // Report the override as well: printing only the default path is
            // misleading when FEATURE_CORPUS_PATH points somewhere else.
            match std::env::var_os("FEATURE_CORPUS_PATH") {
                Some(p) => eprintln!(
                    "No feature corpus cases found (FEATURE_CORPUS_PATH={}, default {DEFAULT_CORPUS_PATH})",
                    std::path::Path::new(&p).display()
                ),
                None => eprintln!("Feature corpus not found at {DEFAULT_CORPUS_PATH}"),
            }
            eprintln!("Set FEATURE_CORPUS_PATH or run scripts/generate_feature_corpus.py");
            return;
        }
        eprintln!("Found {} feature corpus test cases", tests.len());
    }

    /// Decodes every discovered corpus file and compares it with its PNG
    /// reference.
    ///
    /// Panics inside a single case are caught and counted as crashes, so one
    /// bad file does not abort the run. The test fails at the end if any case
    /// errored or crashed, or if no cases were found at all. Ignored by
    /// default because it needs the external corpus.
    #[test]
    #[ignore]
    fn test_all_feature_corpus() {
        let tests = discover_feature_corpus_tests();
        if tests.is_empty() {
            panic!(
                "No feature corpus tests found. Run scripts/generate_feature_corpus.py or set FEATURE_CORPUS_PATH."
            );
        }
        eprintln!("Testing {} feature corpus files...\n", tests.len());
        let mut passed = 0usize;
        let mut failed = 0usize;
        let mut crashed = 0usize;
        let mut failures: Vec<(String, String)> = Vec::new();
        for (name, jxl_path, png_path) in &tests {
            let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
                compare_with_reference(jxl_path, png_path)
            }));
            match result {
                Ok(Ok(())) => {
                    passed += 1;
                }
                Ok(Err(e)) => {
                    eprintln!("FAIL: {name} — {e}");
                    failures.push((name.clone(), e));
                    failed += 1;
                }
                Err(panic_info) => {
                    // `&*` so the inner `dyn Any` is inspected, not the Box itself.
                    let msg = panic_message(&*panic_info);
                    eprintln!("CRASH: {name} — {msg}");
                    failures.push((name.clone(), format!("PANIC: {msg}")));
                    crashed += 1;
                }
            }
            // Periodic progress line for long corpus runs.
            let total = passed + failed + crashed;
            if total.is_multiple_of(100) {
                eprintln!(
                    "[{total}/{}] {passed} pass, {failed} fail, {crashed} crash",
                    tests.len()
                );
            }
        }
        eprintln!();
        eprintln!("=== Feature Corpus Results ===");
        eprintln!("Passed: {passed}");
        eprintln!("Failed: {failed}");
        eprintln!("Crashed: {crashed}");
        eprintln!("Total: {}", tests.len());
        if !failures.is_empty() {
            eprintln!();
            eprintln!("=== Failures ===");
            for (name, err) in &failures {
                eprintln!("  {name}: {err}");
            }
        }
        if failed + crashed > 0 {
            panic!(
                "{} feature corpus tests failed ({failed} error, {crashed} crash)",
                failed + crashed
            );
        }
    }

    /// Regression check: this in-repo sample must decode without error. The
    /// path is relative to the working directory the test runs in.
    #[test]
    fn test_gray_alpha_lossless() {
        let path = std::path::Path::new("resources/test/gray_alpha_lossless.jxl");
        decode_jxl(path).expect("gray_alpha_lossless.jxl should decode successfully");
    }
}