use ultrahdr_core::gainmap::apply::calculate_weight;
use ultrahdr_core::{GainMapChannel, GainMapMetadata};
/// Absolute tolerance for comparing apply outputs and weights against the
/// golden tables.
const TOLERANCE: f32 = 1e-5;
/// libultrahdr apply-gain goldens, embedded at compile time.
/// Data rows have six fields: `weight, gain, base, R, G, B`.
const GOLDEN_LIBULTRAHDR_APPLY: &str = include_str!("data/libultrahdr_apply_gain.csv");
/// libavif goldens, embedded at compile time. Two row shapes share this file:
/// four-field weight rows (`base_headroom_log2, alt, display, weight`, header
/// starting with `base_headroom_log2`) and five-field apply rows
/// (`gain, base, R, G, B`, header starting with `gain`).
const GOLDEN_LIBAVIF: &str = include_str!("data/libavif_apply_gain.csv");
/// libultrahdr compute-gain goldens, embedded at compile time.
/// Data rows have three fields: `sdr, hdr, gain_log2` (header starts with `sdr`).
const GOLDEN_LIBULTRAHDR_COMPUTE: &str = include_str!("data/libultrahdr_compute_gain.csv");
/// Splits a small comma-separated golden file into rows of trimmed fields.
///
/// Lines are trimmed first. Blank lines and lines beginning with `#` are
/// dropped. No quoting or escaping is understood: every `,` separates fields.
fn rows(csv: &str) -> impl Iterator<Item = Vec<&str>> {
    csv.lines().filter_map(|raw| {
        let line = raw.trim();
        let is_data = !(line.is_empty() || line.starts_with('#'));
        is_data.then(|| line.split(',').map(str::trim).collect())
    })
}
/// Builds the uniform gain-map metadata used by the apply-parity tests.
///
/// All three channels share one `GainMapChannel` with:
/// - log2 gain range `[gain_min_log2, gain_max_log2]`;
/// - gamma 1.0;
/// - a 1/64 offset on both the base and alternate sides.
///
/// The base headroom is 0 and the alternate headroom equals `gain_max_log2`.
/// The base color space is selected. Every other field keeps its
/// `GainMapMetadata::default()` value.
fn apply_metadata(gain_min_log2: f64, gain_max_log2: f64) -> GainMapMetadata {
    const OFFSET: f64 = 1.0 / 64.0;
    let shared = GainMapChannel {
        min: gain_min_log2,
        max: gain_max_log2,
        gamma: 1.0,
        base_offset: OFFSET,
        alternate_offset: OFFSET,
    };
    let mut metadata = GainMapMetadata::default();
    metadata.use_base_color_space = true;
    metadata.alternate_hdr_headroom = gain_max_log2;
    metadata.base_hdr_headroom = 0.0;
    metadata.channels = [shared; 3];
    metadata
}
/// Reference single-sample gain-map application, evaluated in `f32`.
///
/// Only `metadata.channels[0]` is read. The result therefore stands for all
/// three channels only when the metadata is uniform, as built by
/// `apply_metadata`.
///
/// Steps:
/// 1. If the channel gamma is positive and not 1, `gain_norm` is decoded as
///    `gain_norm^(1/gamma)`.
/// 2. The decoded value is interpolated between the channel's `min` and `max`
///    (log2 values converted to natural log).
/// 3. The result is scaled by `weight` and exponentiated.
/// 4. The output is `(base + base_offset) * gain - alternate_offset`.
fn apply_one(metadata: &GainMapMetadata, weight: f32, gain_norm: f32, base: f32) -> f32 {
    use core::f64::consts::LN_2;

    let channel = &metadata.channels[0];
    let ln_min = (channel.min * LN_2) as f32;
    let ln_max = (channel.max * LN_2) as f32;

    let needs_gamma = channel.gamma > 0.0 && channel.gamma != 1.0;
    let decoded = if needs_gamma {
        gain_norm.powf(1.0 / (channel.gamma as f32))
    } else {
        gain_norm
    };

    let ln_gain = ln_min + decoded * (ln_max - ln_min);
    let scaled_base = base + channel.base_offset as f32;
    scaled_base * (ln_gain * weight).exp() - channel.alternate_offset as f32
}
/// Checks `apply_one` against libultrahdr's apply-gain golden table.
///
/// Uses six-field rows `weight, gain, base, R, G, B`. A row whose first field
/// does not parse as `f32` (such as a header) is skipped; any other
/// unparsable field panics. The metadata is identical across channels, so one
/// computed value is compared with all three expected outputs.
#[test]
fn libultrahdr_apply_gain_parity() {
    let metadata = apply_metadata(0.0, 2.0);
    let mut max_err = 0.0f32;
    let mut tested = 0;
    for row in rows(GOLDEN_LIBULTRAHDR_APPLY).filter(|fields| fields.len() == 6) {
        let Ok(weight) = row[0].parse::<f32>() else {
            continue;
        };
        let field = |idx: usize| -> f32 { row[idx].parse().unwrap() };
        let gain_norm = field(1);
        let base = field(2);
        let wanted = [field(3), field(4), field(5)];
        let out = apply_one(&metadata, weight, gain_norm, base);
        for (label, expected) in ["R", "G", "B"].iter().zip(wanted) {
            let err = (out - expected).abs();
            assert!(
                err < TOLERANCE,
                "libultrahdr apply mismatch on row {row:?} channel {label}: got {out}, want {expected}, err {err}",
            );
            max_err = max_err.max(err);
        }
        tested += 1;
    }
    assert!(tested > 50, "too few rows tested: {tested}");
    eprintln!("libultrahdr_apply: {tested} rows, max_err={max_err:.3e}");
}
/// Checks `calculate_weight` against libavif's weight golden rows.
///
/// Only four-field rows are used:
/// `base_headroom_log2, alt, display, expected_weight`.
///
/// The base and alternate headrooms go into otherwise-default metadata. The
/// display value is turned into a linear boost with `exp2` before the call.
#[test]
fn libavif_weight_parity() {
    let mut max_err = 0.0f32;
    let mut tested = 0;
    let weight_rows =
        rows(GOLDEN_LIBAVIF).filter(|fields| fields.len() == 4 && fields[0] != "base_headroom_log2");
    for row in weight_rows {
        let Ok(base) = row[0].parse::<f32>() else {
            continue;
        };
        let field = |idx: usize| -> f32 { row[idx].parse().unwrap() };
        let alt = field(1);
        let display = field(2);
        let expected = field(3);

        let mut metadata = GainMapMetadata::default();
        metadata.base_hdr_headroom = f64::from(base);
        metadata.alternate_hdr_headroom = f64::from(alt);
        let display_boost = f64::from(display).exp2() as f32;

        let got = calculate_weight(display_boost, &metadata);
        let err = (got - expected).abs();
        assert!(
            err < TOLERANCE,
            "libavif weight mismatch on row {row:?}: got {got}, want {expected}, err {err}",
        );
        max_err = max_err.max(err);
        tested += 1;
    }
    assert!(tested >= 5, "too few rows tested: {tested}");
    eprintln!("libavif_weight: {tested} rows, max_err={max_err:.3e}");
}
/// Checks `apply_one` (at weight 1) against libavif's apply golden rows.
///
/// Uses five-field rows `gain, base, R, G, B`. The `gain` header row and any
/// row whose first field is not an `f32` are skipped. One computed value is
/// compared with all three expected channels because the metadata is uniform.
#[test]
fn libavif_apply_gain_parity() {
    let metadata = apply_metadata(0.0, 2.0);
    let mut max_err = 0.0f32;
    let mut tested = 0;
    let apply_rows = rows(GOLDEN_LIBAVIF).filter(|fields| fields.len() == 5 && fields[0] != "gain");
    for row in apply_rows {
        let Ok(gain_norm) = row[0].parse::<f32>() else {
            continue;
        };
        let field = |idx: usize| -> f32 { row[idx].parse().unwrap() };
        let base = field(1);
        let wanted = [field(2), field(3), field(4)];
        let out = apply_one(&metadata, 1.0, gain_norm, base);
        for (label, expected) in ["R", "G", "B"].iter().zip(wanted) {
            let err = (out - expected).abs();
            assert!(
                err < TOLERANCE,
                "libavif apply mismatch on row {row:?} channel {label}: got {out}, want {expected}, err {err}",
            );
            max_err = max_err.max(err);
        }
        tested += 1;
    }
    assert!(tested >= 30, "too few rows tested: {tested}");
    eprintln!("libavif_apply: {tested} rows, max_err={max_err:.3e}");
}
/// Checks libultrahdr's compute-gain goldens against the raw log2 ratio,
/// including libultrahdr's near-black clamp.
///
/// Each data row is `sdr, hdr, gain_log2`. The raw gain is
/// `log2((hdr + 1e-7) / (sdr + 1e-7))`.
///
/// For near-black pixels (`sdr < 2/255`) libultrahdr caps the gain at 2.3 so
/// dark pixels cannot signal large gains. The golden is therefore expected to
/// equal `min(raw, 2.3)` there. That clamp is the libultrahdr-specific
/// divergence this test documents.
///
/// Every row is verified, and near-black rows are no longer skipped. Clamped
/// near-black rows (raw above 2.3) are counted separately so the test proves
/// the clamp is actually exercised.
#[test]
fn libultrahdr_compute_gain_documented_divergence() {
    // Added to numerator and denominator so the ratio never divides by zero.
    const RATIO_OFFSET: f32 = 1e-7;
    // SDR values strictly below this are treated as near black.
    const NEAR_BLACK_SDR: f32 = 2.0 / 255.0;
    // Upper bound libultrahdr applies to near-black gains (log2 domain).
    const NEAR_BLACK_MAX_GAIN_LOG2: f32 = 2.3;
    // Log2-domain tolerance between the reference formula and the golden.
    const COMPUTE_TOLERANCE: f32 = 1e-3;

    let mut clamped = 0;
    let mut near_black_unclamped = 0;
    let mut unclamped_checked = 0;
    for row in rows(GOLDEN_LIBULTRAHDR_COMPUTE) {
        if row.len() != 3 || row[0] == "sdr" {
            continue;
        }
        let sdr: f32 = row[0].parse().unwrap();
        let hdr: f32 = row[1].parse().unwrap();
        let gain_log2: f32 = row[2].parse().unwrap();

        let raw = ((hdr + RATIO_OFFSET) / (sdr + RATIO_OFFSET)).log2();
        let near_black = sdr < NEAR_BLACK_SDR;
        let expected = if near_black {
            raw.min(NEAR_BLACK_MAX_GAIN_LOG2)
        } else {
            raw
        };
        let err = (expected - gain_log2).abs();
        assert!(
            err < COMPUTE_TOLERANCE,
            "compute mismatch on row {row:?}: raw={raw}, expected={expected}, golden={gain_log2}, err={err}",
        );

        if !near_black {
            unclamped_checked += 1;
        } else if raw > NEAR_BLACK_MAX_GAIN_LOG2 {
            clamped += 1;
        } else {
            near_black_unclamped += 1;
        }
    }
    assert!(
        unclamped_checked > 30,
        "too few unclamped rows: {unclamped_checked}"
    );
    assert!(clamped > 5, "too few clamped rows: {clamped}");
    eprintln!(
        "libultrahdr_compute: {unclamped_checked} unclamped rows match raw formula, \
         {clamped} rows demonstrate the libultrahdr-specific 2.3 near-black clamp, \
         {near_black_unclamped} near-black rows fall below the clamp",
    );
}
/// Cross-checks libultrahdr, libavif and `apply_one` at weight 1.
///
/// The libavif apply rows (`gain, base, R, G, B`) are indexed by the exact
/// `f32` bit patterns of `(gain, base)`; if a pair repeats, the last row wins.
///
/// For every libultrahdr row with weight 1 (within 1e-6) whose `(gain, base)`
/// also appears in the libavif table, both of these must hold within
/// `TOLERANCE`:
/// - libultrahdr's R output matches libavif's R output;
/// - `apply_one` matches libultrahdr's R output.
#[test]
fn cross_check_libultrahdr_libavif_agree_on_apply() {
    use std::collections::HashMap;
    let metadata = apply_metadata(0.0, 2.0);

    // f32 is not Eq/Hash, so key on bit patterns; both tables are expected to
    // print the same sample points, which parse to identical bits.
    let avif: HashMap<(u32, u32), f32> = rows(GOLDEN_LIBAVIF)
        .filter(|row| row.len() == 5 && row[0] != "gain")
        .filter_map(|row| {
            let g: f32 = row[0].parse().ok()?;
            let b: f32 = row[1].parse().unwrap();
            let r: f32 = row[2].parse().unwrap();
            Some(((g.to_bits(), b.to_bits()), r))
        })
        .collect();

    let mut shared = 0;
    for row in rows(GOLDEN_LIBULTRAHDR_APPLY).filter(|row| row.len() == 6) {
        let Ok(weight) = row[0].parse::<f32>() else {
            continue;
        };
        if (weight - 1.0).abs() > 1e-6 {
            continue;
        }
        let gain: f32 = row[1].parse().unwrap();
        let base: f32 = row[2].parse().unwrap();
        let uhdr_r: f32 = row[3].parse().unwrap();
        let Some(&avif_r) = avif.get(&(gain.to_bits(), base.to_bits())) else {
            continue;
        };

        let err = (uhdr_r - avif_r).abs();
        assert!(
            err < TOLERANCE,
            "uhdr vs avif disagree at gain={gain}, base={base}: uhdr={uhdr_r} avif={avif_r}, err {err}",
        );

        let ours = apply_one(&metadata, 1.0, gain, base);
        let err_ours = (ours - uhdr_r).abs();
        assert!(
            err_ours < TOLERANCE,
            "ours diverges at gain={gain}, base={base}: ours={ours} uhdr={uhdr_r}, err {err_ours}",
        );
        shared += 1;
    }
    assert!(shared >= 5, "too few shared (gain,base) points: {shared}");
    eprintln!("cross_check: {shared} shared (gain, base) points agree across all three impls");
}