use fast_ssim2::compute_ssimulacra2;
use fast_ssim2::reference_data::{REFERENCE_CASES, ReferenceCase};
use sha2::{Digest, Sha256};
use yuvxyb::{ColorPrimaries, Rgb, TransferCharacteristic};
struct Lcg {
state: u64,
}
impl Lcg {
const fn new(seed: u64) -> Self {
Self { state: seed }
}
fn next_u8(&mut self) -> u8 {
self.state = self
.state
.wrapping_mul(6364136223846793005)
.wrapping_add(1442695040888963407);
((self.state >> 33) & 0xFF) as u8
}
}
fn gen_uniform(width: usize, height: usize, r: u8, g: u8, b: u8) -> Vec<u8> {
vec![r, g, b]
.into_iter()
.cycle()
.take(width * height * 3)
.collect()
}
fn gen_gradient_h(width: usize, height: usize) -> Vec<u8> {
let mut data = Vec::with_capacity(width * height * 3);
for _y in 0..height {
for x in 0..width {
let val = if width > 1 {
(x * 255 / (width - 1)) as u8
} else {
128
};
data.extend_from_slice(&[val, val, val]);
}
}
data
}
fn gen_gradient_v(width: usize, height: usize) -> Vec<u8> {
let mut data = Vec::with_capacity(width * height * 3);
for y in 0..height {
let val = if height > 1 {
(y * 255 / (height - 1)) as u8
} else {
128
};
for _x in 0..width {
data.extend_from_slice(&[val, val, val]);
}
}
data
}
fn gen_gradient_diag(width: usize, height: usize) -> Vec<u8> {
let mut data = Vec::with_capacity(width * height * 3);
let max_dist = width + height - 2;
for y in 0..height {
for x in 0..width {
let val = match ((x + y) * 255).checked_div(max_dist) {
Some(v) => v as u8,
None => 128,
};
data.extend_from_slice(&[val, val, val]);
}
}
data
}
fn gen_checkerboard(width: usize, height: usize, cell_size: usize) -> Vec<u8> {
let mut data = Vec::with_capacity(width * height * 3);
for y in 0..height {
for x in 0..width {
let val = if ((x / cell_size) + (y / cell_size)).is_multiple_of(2) {
255
} else {
0
};
data.extend_from_slice(&[val, val, val]);
}
}
data
}
fn gen_noise(width: usize, height: usize, seed: u64) -> Vec<u8> {
let mut lcg = Lcg::new(seed);
let mut data = Vec::with_capacity(width * height * 3);
for _ in 0..width * height {
data.push(lcg.next_u8());
data.push(lcg.next_u8());
data.push(lcg.next_u8());
}
data
}
fn gen_edge(width: usize, height: usize, vertical: bool) -> Vec<u8> {
let mut data = Vec::with_capacity(width * height * 3);
for y in 0..height {
for x in 0..width {
let val = if vertical {
if x < width / 2 { 0 } else { 255 }
} else if y < height / 2 {
0
} else {
255
};
data.extend_from_slice(&[val, val, val]);
}
}
data
}
fn apply_box_blur_8x8(input: &[u8], width: usize, height: usize) -> Vec<u8> {
let mut output = vec![0u8; width * height * 3];
const KERNEL_SIZE: i32 = 8;
const HALF_KERNEL: i32 = KERNEL_SIZE / 2;
for y in 0..height {
for x in 0..width {
let mut sum = [0u32; 3];
let mut count = 0u32;
for ky in -HALF_KERNEL..HALF_KERNEL {
for kx in -HALF_KERNEL..HALF_KERNEL {
let ny = (y as i32 + ky).clamp(0, height as i32 - 1) as usize;
let nx = (x as i32 + kx).clamp(0, width as i32 - 1) as usize;
let idx = (ny * width + nx) * 3;
sum[0] += input[idx] as u32;
sum[1] += input[idx + 1] as u32;
sum[2] += input[idx + 2] as u32;
count += 1;
}
}
let out_idx = (y * width + x) * 3;
output[out_idx] = (sum[0] / count) as u8;
output[out_idx + 1] = (sum[1] / count) as u8;
output[out_idx + 2] = (sum[2] / count) as u8;
}
}
output
}
fn apply_sharpen(input: &[u8], width: usize, height: usize) -> Vec<u8> {
let mut output = vec![0u8; width * height * 3];
for y in 0..height {
for x in 0..width {
for c in 0..3 {
let idx = (y * width + x) * 3 + c;
let center = input[idx] as i32;
let top = if y > 0 {
input[((y - 1) * width + x) * 3 + c] as i32
} else {
center
};
let bottom = if y < height - 1 {
input[((y + 1) * width + x) * 3 + c] as i32
} else {
center
};
let left = if x > 0 {
input[(y * width + (x - 1)) * 3 + c] as i32
} else {
center
};
let right = if x < width - 1 {
input[(y * width + (x + 1)) * 3 + c] as i32
} else {
center
};
let sharpened = 5 * center - top - bottom - left - right;
output[idx] = sharpened.clamp(0, 255) as u8;
}
}
}
output
}
fn apply_yuv_roundtrip(input: &[u8], width: usize, height: usize) -> Vec<u8> {
let mut output = vec![0u8; width * height * 3];
for i in 0..width * height {
let idx = i * 3;
let r = input[idx] as f32;
let g = input[idx + 1] as f32;
let b = input[idx + 2] as f32;
let y = 0.299 * r + 0.587 * g + 0.114 * b;
let u = -0.14713 * r - 0.28886 * g + 0.436 * b + 128.0;
let v = 0.615 * r - 0.51499 * g - 0.10001 * b + 128.0;
let r_out = y + 1.13983 * (v - 128.0);
let g_out = y - 0.39465 * (u - 128.0) - 0.58060 * (v - 128.0);
let b_out = y + 2.03211 * (u - 128.0);
output[idx] = r_out.clamp(0.0, 255.0) as u8;
output[idx + 1] = g_out.clamp(0.0, 255.0) as u8;
output[idx + 2] = b_out.clamp(0.0, 255.0) as u8;
}
output
}
fn generate_test_image(case: &ReferenceCase) -> (Vec<u8>, Vec<u8>) {
let name = case.name;
let width = case.width;
let height = case.height;
if name.starts_with("perfect_match") {
let data = gen_uniform(width, height, 128, 128, 128);
(data.clone(), data)
} else if let Some(shift_str) = name.strip_prefix("uniform_shift_") {
if let Some(shift) = shift_str
.split('_')
.next()
.and_then(|s| s.parse::<u8>().ok())
{
let source = gen_uniform(width, height, 128, 128, 128);
let distorted = gen_uniform(width, height, 128 + shift, 128 + shift, 128 + shift);
(source, distorted)
} else {
panic!("Invalid uniform_shift test case: {}", name);
}
} else if name.starts_with("gradient_h_") {
let grad = gen_gradient_h(width, height);
(grad.clone(), grad)
} else if name.starts_with("gradient_v_") {
let grad = gen_gradient_v(width, height);
(grad.clone(), grad)
} else if name.starts_with("gradient_diag_") {
let grad = gen_gradient_diag(width, height);
(grad.clone(), grad)
} else if let Some(rest) = name.strip_prefix("checkerboard_") {
if let Some(cell_size) = rest.split('_').next().and_then(|s| s.parse::<usize>().ok()) {
let checker = gen_checkerboard(width, height, cell_size);
(checker.clone(), checker)
} else {
panic!("Invalid checkerboard test case: {}", name);
}
} else if let Some(rest) = name.strip_prefix("noise_seed_") {
if let Some(seed) = rest.split('_').next().and_then(|s| s.parse::<u64>().ok()) {
let noise = gen_noise(width, height, seed);
(noise.clone(), noise)
} else {
panic!("Invalid noise test case: {}", name);
}
} else if name.starts_with("edge_vertical") {
let edge = gen_edge(width, height, true);
(edge.clone(), edge)
} else if name.starts_with("edge_horizontal") {
let edge = gen_edge(width, height, false);
(edge.clone(), edge)
} else if name.contains("gradient_vs_uniform") {
let grad = gen_gradient_h(width, height);
let uniform = gen_uniform(width, height, 128, 128, 128);
(grad, uniform)
} else if name.contains("noise_vs_uniform") {
let noise = gen_noise(width, height, 42);
let uniform = gen_uniform(width, height, 128, 128, 128);
(noise, uniform)
} else if name.contains("gradient_vs_boxblur8x8") {
let source = gen_gradient_h(width, height);
let blurred = apply_box_blur_8x8(&source, width, height);
(source, blurred)
} else if name.contains("noise_vs_sharpen") {
let source = gen_noise(width, height, 999);
let sharpened = apply_sharpen(&source, width, height);
(source, sharpened)
} else if name.contains("gradient_vs_yuv_roundtrip") {
let source = gen_gradient_diag(width, height);
let yuv_roundtrip = apply_yuv_roundtrip(&source, width, height);
(source, yuv_roundtrip)
} else if name.contains("edge_vs_boxblur8x8") {
let source = gen_edge(width, height, true);
let blurred = apply_box_blur_8x8(&source, width, height);
(source, blurred)
} else {
panic!("Unknown test case pattern: {}", name);
}
}
#[test]
fn test_reference_parity() {
assert!(
!REFERENCE_CASES.is_empty(),
"no reference cases compiled in — the C++ parity gate would be vacuous"
);
let mut failures = Vec::new();
let mut max_error = 0.0f64;
for (i, case) in REFERENCE_CASES.iter().enumerate() {
let (source_data, distorted_data) = generate_test_image(case);
let source_hash = Sha256::digest(&source_data)
.iter()
.map(|b| format!("{b:02x}"))
.collect::<String>();
let distorted_hash = Sha256::digest(&distorted_data)
.iter()
.map(|b| format!("{b:02x}"))
.collect::<String>();
if source_hash != case.source_hash {
eprintln!(
"\nERROR: Source image hash mismatch for {}!\nExpected: {}\nGot: {}\nThis indicates the image generation algorithm changed.",
case.name, case.source_hash, source_hash
);
panic!("Image generation changed for {}", case.name);
}
if distorted_hash != case.distorted_hash {
eprintln!(
"\nERROR: Distorted image hash mismatch for {}!\nExpected: {}\nGot: {}\nThis indicates the image generation algorithm changed.",
case.name, case.distorted_hash, distorted_hash
);
panic!("Image generation changed for {}", case.name);
}
let source_rgb: Vec<[f32; 3]> = source_data
.chunks_exact(3)
.map(|c| {
[
c[0] as f32 / 255.0,
c[1] as f32 / 255.0,
c[2] as f32 / 255.0,
]
})
.collect();
let distorted_rgb: Vec<[f32; 3]> = distorted_data
.chunks_exact(3)
.map(|c| {
[
c[0] as f32 / 255.0,
c[1] as f32 / 255.0,
c[2] as f32 / 255.0,
]
})
.collect();
let nz_width = std::num::NonZeroUsize::new(case.width).unwrap();
let nz_height = std::num::NonZeroUsize::new(case.height).unwrap();
let source = Rgb::new(
source_rgb,
nz_width,
nz_height,
TransferCharacteristic::SRGB,
ColorPrimaries::BT709,
)
.unwrap();
let distorted = Rgb::new(
distorted_rgb,
nz_width,
nz_height,
TransferCharacteristic::SRGB,
ColorPrimaries::BT709,
)
.unwrap();
let score = compute_ssimulacra2(source, distorted).unwrap();
let error = (score - case.expected_score).abs();
max_error = max_error.max(error);
let tolerance = if case.name.contains("uniform_shift") {
10.0 } else if case.name.contains("boxblur8x8")
|| case.name.contains("sharpen")
|| case.name.contains("yuv_roundtrip")
{
0.2 } else if case.name.contains("_vs_")
|| case.name.starts_with("perfect_match")
|| case.name.starts_with("gradient_h_")
|| case.name.starts_with("gradient_v_")
|| case.name.starts_with("checkerboard_")
|| case.name.starts_with("noise_seed_")
|| case.name.starts_with("edge_")
{
0.01 } else {
0.05 };
if error > tolerance {
failures.push((i, case.name, case.expected_score, score, error));
}
}
if !failures.is_empty() {
eprintln!(
"\n{} / {} tests FAILED:",
failures.len(),
REFERENCE_CASES.len()
);
eprintln!(
"{:<5} {:<50} {:>15} {:>15} {:>10}",
"Index", "Name", "Expected", "Actual", "Error"
);
eprintln!("{:-<100}", "");
for (i, name, expected, actual, error) in &failures {
eprintln!(
"{:<5} {:<50} {:>15.6} {:>15.6} {:>10.6}",
i, name, expected, actual, error
);
}
eprintln!("\nMax error: {:.6}", max_error);
panic!("{} tests failed", failures.len());
}
#[derive(Debug, Clone)]
struct ErrorCase {
name: &'static str,
expected: f64,
actual: f64,
error: f64,
}
let all_errors: Vec<ErrorCase> = REFERENCE_CASES
.iter()
.map(|case| {
let (source_data, distorted_data) = generate_test_image(case);
let source_rgb: Vec<[f32; 3]> = source_data
.chunks_exact(3)
.map(|c| {
[
c[0] as f32 / 255.0,
c[1] as f32 / 255.0,
c[2] as f32 / 255.0,
]
})
.collect();
let distorted_rgb: Vec<[f32; 3]> = distorted_data
.chunks_exact(3)
.map(|c| {
[
c[0] as f32 / 255.0,
c[1] as f32 / 255.0,
c[2] as f32 / 255.0,
]
})
.collect();
let nz_width = std::num::NonZeroUsize::new(case.width).unwrap();
let nz_height = std::num::NonZeroUsize::new(case.height).unwrap();
let source = Rgb::new(
source_rgb,
nz_width,
nz_height,
TransferCharacteristic::SRGB,
ColorPrimaries::BT709,
)
.unwrap();
let distorted = Rgb::new(
distorted_rgb,
nz_width,
nz_height,
TransferCharacteristic::SRGB,
ColorPrimaries::BT709,
)
.unwrap();
let score = compute_ssimulacra2(source, distorted).unwrap();
ErrorCase {
name: case.name,
expected: case.expected_score,
actual: score,
error: (score - case.expected_score).abs(),
}
})
.collect();
let mut sorted_errors = all_errors.clone();
sorted_errors.sort_by(|a, b| b.error.partial_cmp(&a.error).unwrap());
println!("\n{:=^100}", " REFERENCE PARITY TEST RESULTS ");
println!(
"All {} reference tests passed! Max error: {:.6}",
REFERENCE_CASES.len(),
max_error
);
let mut error_values: Vec<f64> = all_errors.iter().map(|e| e.error).collect();
error_values.sort_by(|a, b| a.partial_cmp(b).unwrap());
println!(
"\nError percentiles: p50={:.4}, p90={:.4}, p95={:.4}, p99={:.4}",
error_values[error_values.len() / 2],
error_values[(error_values.len() * 90) / 100],
error_values[(error_values.len() * 95) / 100],
error_values[(error_values.len() * 99) / 100]
);
println!(
"Errors >0.1: {}, >0.5: {}, >1.0: {}",
error_values.iter().filter(|&&e| e > 0.1).count(),
error_values.iter().filter(|&&e| e > 0.5).count(),
error_values.iter().filter(|&&e| e > 1.0).count()
);
println!("\n{:-^100}", " Top 10 Largest Errors ");
println!(
"{:<50} {:>15} {:>15} {:>10}",
"Test Case", "Expected", "Actual", "Error"
);
println!("{:-<100}", "");
for case in sorted_errors.iter().take(10) {
println!(
"{:<50} {:>15.6} {:>15.6} {:>10.6}",
case.name, case.expected, case.actual, case.error
);
}
println!("\n{:-^100}", " Error Breakdown by Pattern Type ");
let mut pattern_errors: std::collections::HashMap<&str, Vec<f64>> =
std::collections::HashMap::new();
for case in &all_errors {
let pattern = if case.name.contains("uniform_shift") {
"uniform_shift"
} else if case.name.contains("boxblur8x8")
|| case.name.contains("sharpen")
|| case.name.contains("yuv_roundtrip")
{
"distortions"
} else if case.name.contains("_vs_") {
"synthetic_vs"
} else if case.name.starts_with("perfect_match") {
"perfect_match"
} else if case.name.starts_with("gradient") {
"gradients"
} else if case.name.starts_with("checkerboard") {
"checkerboard"
} else if case.name.starts_with("noise_seed") {
"noise"
} else if case.name.starts_with("edge") {
"edges"
} else {
"other"
};
pattern_errors.entry(pattern).or_default().push(case.error);
}
println!(
"{:<20} {:>10} {:>15} {:>15} {:>15}",
"Pattern", "Count", "Max Error", "Mean Error", "P95 Error"
);
println!("{:-<80}", "");
let mut pattern_names: Vec<_> = pattern_errors.keys().copied().collect();
pattern_names.sort();
for pattern in pattern_names {
if let Some(errors) = pattern_errors.get_mut(pattern) {
errors.sort_by(|a, b| a.partial_cmp(b).unwrap());
let max = errors.iter().copied().fold(f64::NEG_INFINITY, f64::max);
let mean = errors.iter().sum::<f64>() / errors.len() as f64;
let p95 = errors[(errors.len() * 95) / 100];
println!(
"{:<20} {:>10} {:>15.6} {:>15.6} {:>15.6}",
pattern,
errors.len(),
max,
mean,
p95
);
}
}
println!("{:=^100}", "");
}