use crate::Bt2408Tonemapper;
use crate::curves::{filmic_narkowicz, hable_filmic, reinhard_extended, reinhard_simple};
/// A standard tone curve matched against a LUT by [`detect_standard`], together with
/// how closely it fits.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct DetectedCurve {
    /// Name of the matched curve family, e.g. `"Reinhard"`, `"ExtendedReinhard"` or
    /// `"Bt2408"`.
    pub name: &'static str,
    /// Root-mean-square difference between the LUT entries and the matched reference
    /// curve sampled on the same grid.
    pub rms_error: f32,
    /// Parameters of the matched curve; [`DetectedParams::None`] for curves that take
    /// no parameters.
    pub params: DetectedParams,
}

/// Parameters of a curve matched by [`detect_standard`].
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub enum DetectedParams {
    /// The matched curve takes no parameters.
    None,
    /// Extended Reinhard curve.
    ExtendedReinhard {
        /// The `l_max` argument passed to `reinhard_extended`.
        l_max: f32,
    },
    /// Curve produced by `Bt2408Tonemapper`.
    Bt2408 {
        /// First argument to `Bt2408Tonemapper::new` (presumably the content peak in
        /// nits); LUT inputs are multiplied by it before tone mapping.
        content_nits: f32,
        /// Second argument to `Bt2408Tonemapper::new` (presumably the display peak in
        /// nits); tone-mapped values are divided by it to form the LUT output.
        display_nits: f32,
    },
}
pub fn detect_standard(lut: &[f32], max_hdr: f32, threshold: f32) -> Option<DetectedCurve> {
let n = lut.len();
if n < 256 {
return None;
}
let mut best: Option<DetectedCurve> = None;
type ScalarCurve = fn(f32) -> f32;
let parameterless: &[(&str, ScalarCurve)] = &[
("Reinhard", reinhard_simple as fn(f32) -> f32),
("Narkowicz", filmic_narkowicz),
("HableFilmic", hable_filmic),
];
for (name, curve_fn) in parameterless {
let rms = compute_rms(lut, max_hdr, |x| curve_fn(x).min(1.0));
if rms < threshold {
let candidate = DetectedCurve {
name,
rms_error: rms,
params: DetectedParams::None,
};
if best.as_ref().is_none_or(|b| rms < b.rms_error) {
best = Some(candidate);
}
}
}
let l_max_candidates = [1.0, 2.0, 4.0, 8.0, 16.0, max_hdr, max_hdr * 2.0];
for &l_max in &l_max_candidates {
let rms = compute_rms(lut, max_hdr, |x| reinhard_extended(x, l_max).min(1.0));
if rms < threshold {
let candidate = DetectedCurve {
name: "ExtendedReinhard",
rms_error: rms,
params: DetectedParams::ExtendedReinhard { l_max },
};
if best.as_ref().is_none_or(|b| rms < b.rms_error) {
best = Some(candidate);
}
}
}
let content_nits = [1000.0, 2000.0, 4000.0, 10000.0];
let display_nits = [100.0, 203.0, 400.0, 1000.0];
for &cn in &content_nits {
for &dn in &display_nits {
if cn <= dn {
continue;
}
let tm = Bt2408Tonemapper::new(cn, dn);
let rms = compute_rms(lut, max_hdr, |x| {
let nits = x * cn;
tm.tonemap_nits(nits) / dn
});
if rms < threshold {
let candidate = DetectedCurve {
name: "Bt2408",
rms_error: rms,
params: DetectedParams::Bt2408 {
content_nits: cn,
display_nits: dn,
},
};
if best.as_ref().is_none_or(|b| rms < b.rms_error) {
best = Some(candidate);
}
}
}
}
best
}
/// Root-mean-square difference between `lut` and `curve` sampled on the LUT's grid.
///
/// Entry `i` of an `n`-entry LUT is compared against `curve(x)` with
/// `x = i / (n - 1) * max_hdr`, so the first entry sits at `0.0` and the last at
/// `max_hdr`. A single-entry LUT is compared at `x = 0.0`.
///
/// Squared differences are accumulated in `f64` to limit rounding error over long
/// LUTs. Returns `f32::MAX` for an empty LUT, so it fails any finite threshold. A NaN
/// in the LUT or in the curve output propagates to a NaN result.
fn compute_rms(lut: &[f32], max_hdr: f32, curve: impl Fn(f32) -> f32) -> f32 {
    let n = lut.len();
    if n == 0 {
        return f32::MAX;
    }
    // With one entry, `n - 1 == 0` would make every x `0 / 0 = NaN`; clamp the
    // denominator so that lone sample sits at x = 0 instead.
    let denom = (n - 1).max(1) as f32;
    let sum_sq: f64 = lut
        .iter()
        .enumerate()
        .map(|(i, &lut_val)| {
            let x = (i as f32 / denom) * max_hdr;
            // Difference taken in f32 (as before), squared and summed in f64.
            let diff = f64::from(lut_val - curve(x));
            diff * diff
        })
        .sum();
    libm::sqrt(sum_sq / n as f64) as f32
}
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec::Vec;

    /// Samples `curve` on a 4096-point grid: entry `i` holds
    /// `curve(i / 4095 * max_hdr)`, matching the grid `compute_rms` compares against.
    fn generate_lut(max_hdr: f32, curve: impl Fn(f32) -> f32) -> Vec<f32> {
        let n = 4096;
        (0..n)
            .map(|i| {
                let x = (i as f32 / (n - 1) as f32) * max_hdr;
                curve(x)
            })
            .collect()
    }

    #[test]
    fn detect_reinhard() {
        let lut = generate_lut(4.0, reinhard_simple);
        let r = detect_standard(&lut, 4.0, 0.01).expect("should detect reinhard");
        assert_eq!(r.name, "Reinhard");
        assert!(r.rms_error < 1e-5, "rms: {}", r.rms_error);
    }

    #[test]
    fn detect_narkowicz() {
        let lut = generate_lut(4.0, filmic_narkowicz);
        let r = detect_standard(&lut, 4.0, 0.01).expect("should detect narkowicz");
        assert_eq!(r.name, "Narkowicz");
    }

    #[test]
    fn detect_hable() {
        let lut = generate_lut(4.0, |x| hable_filmic(x).min(1.0));
        let r = detect_standard(&lut, 4.0, 0.01).expect("should detect hable");
        assert_eq!(r.name, "HableFilmic");
    }

    #[test]
    fn detect_extended_reinhard() {
        let l_max = 4.0;
        let lut = generate_lut(4.0, |x| reinhard_extended(x, l_max));
        let r = detect_standard(&lut, 4.0, 0.01).expect("should detect extended reinhard");
        assert_eq!(r.name, "ExtendedReinhard");
        // `let else` so a params variant mismatch fails instead of passing silently.
        let DetectedParams::ExtendedReinhard { l_max: detected } = r.params else {
            panic!("expected ExtendedReinhard params, got {:?}", r.params);
        };
        assert!(
            (detected - l_max).abs() < 0.1,
            "l_max: expected {l_max}, got {detected}"
        );
    }

    #[test]
    fn detect_bt2408() {
        // (1000, 100) is the first BT.2408 pair tried, so it also wins any exact tie
        // with a later pair.
        let (cn, dn) = (1000.0_f32, 100.0_f32);
        let tm = Bt2408Tonemapper::new(cn, dn);
        let lut = generate_lut(1.0, |x| tm.tonemap_nits(x * cn) / dn);
        let r = detect_standard(&lut, 1.0, 0.01).expect("should detect bt2408");
        assert_eq!(r.name, "Bt2408");
        let DetectedParams::Bt2408 {
            content_nits,
            display_nits,
        } = r.params
        else {
            panic!("expected Bt2408 params, got {:?}", r.params);
        };
        assert_eq!((content_nits, display_nits), (cn, dn));
    }

    #[test]
    fn no_match_on_linear_ramp() {
        // A straight line from 0.2 to 0.5 resembles none of the candidate curves.
        let lut: Vec<f32> = (0..4096)
            .map(|i| (i as f32 / 4095.0) * 0.3 + 0.2)
            .collect();
        let result = detect_standard(&lut, 4.0, 0.01);
        assert!(result.is_none(), "linear ramp LUT should not match");
    }
}