#![allow(clippy::excessive_precision)]
use crate::ext::ColorPrimariesExt;
use crate::gamut::{GamutMatrix, mat3_mul};
use zenpixels::ColorPrimaries;
/// Oklab `M1` matrix: CIE XYZ → LMS (approximate cone responses).
///
/// The values match the `M1` matrix published in Björn Ottosson's Oklab
/// definition, which is specified for XYZ relative to a D65 white point.
/// [`rgb_to_lms_matrix`] composes this with a set of primaries' RGB → XYZ
/// matrix so that a single matrix takes RGB straight to LMS.
///
/// Digits are kept as published. They exceed `f32` precision, which is what
/// the file-level `clippy::excessive_precision` allow silences.
pub const LMS_FROM_XYZ: GamutMatrix = [
[0.8189330101, 0.3618667424, -0.1288597137],
[0.0329845436, 0.9293118715, 0.0361456387],
[0.0482003018, 0.2643662691, 0.6338517070],
];
/// Inverse of [`LMS_FROM_XYZ`] (Oklab `M1⁻¹`): LMS → CIE XYZ.
///
/// The values match the published inverse from the Oklab definition, to the
/// digits given (so the product with `LMS_FROM_XYZ` is identity only to
/// within rounding). [`lms_to_rgb_matrix`] composes this with a set of
/// primaries' XYZ → RGB matrix.
pub const XYZ_FROM_LMS: GamutMatrix = [
[1.2270138511, -0.5577999807, 0.2812561490],
[-0.0405801784, 1.1122568696, -0.0716766787],
[-0.0763812845, -0.4214819784, 1.5861632204],
];
/// Oklab `M2` matrix: cube-rooted LMS `[l', m', s']` → Oklab `[L, a, b]`.
///
/// Applied in [`rgb_to_oklab`] after the per-channel cube root.
/// - Row 0 produces lightness `L`; its entries sum to ≈ 1.
/// - Rows 1 and 2 produce the opponent axes `a` and `b`; each sums to ≈ 0.
///
/// Because of those row sums, equal cone responses (achromatic colors) map
/// to `a ≈ b ≈ 0`.
pub const OKLAB_FROM_LMS_CBRT: GamutMatrix = [
[0.2104542553, 0.7936177850, -0.0040720468],
[1.9779984951, -2.4285922050, 0.4505937099],
[0.0259040371, 0.7827717662, -0.8086757660],
];
/// Inverse of [`OKLAB_FROM_LMS_CBRT`] (Oklab `M2⁻¹`): Oklab `[L, a, b]` →
/// cube-rooted LMS `[l', m', s']`.
///
/// The first column is all `1.0`, so lightness `L` contributes equally to
/// every cone channel. Used in [`oklab_to_rgb`] before cubing back to LMS.
pub const LMS_CBRT_FROM_OKLAB: GamutMatrix = [
[1.0, 0.3963377774, 0.2158037573],
[1.0, -0.1055613458, -0.0638541728],
[1.0, -0.0894841775, -1.2914855480],
];
/// Builds the combined RGB → LMS matrix for `primaries`.
///
/// Composes [`LMS_FROM_XYZ`] with the primaries' RGB → XYZ matrix via
/// `mat3_mul`. The result can be passed as `m1` to [`rgb_to_oklab`], so a
/// pixel needs only one matrix-vector product to reach LMS.
///
/// Returns `None` when `primaries` has no RGB → XYZ matrix (for example
/// `ColorPrimaries::Unknown`).
pub fn rgb_to_lms_matrix(primaries: ColorPrimaries) -> Option<GamutMatrix> {
    primaries
        .to_xyz_matrix()
        .map(|rgb_to_xyz| mat3_mul(&LMS_FROM_XYZ, rgb_to_xyz))
}
/// Builds the combined LMS → RGB matrix for `primaries`. This is the
/// counterpart of [`rgb_to_lms_matrix`].
///
/// Composes the primaries' XYZ → RGB matrix with [`XYZ_FROM_LMS`] via
/// `mat3_mul`. The result can be passed as `m1_inv` to [`oklab_to_rgb`].
///
/// Returns `None` when `primaries` has no XYZ → RGB matrix (for example
/// `ColorPrimaries::Unknown`).
pub fn lms_to_rgb_matrix(primaries: ColorPrimaries) -> Option<GamutMatrix> {
    primaries
        .from_xyz_matrix()
        .map(|xyz_to_rgb| mat3_mul(xyz_to_rgb, &XYZ_FROM_LMS))
}
/// Approximate, sign-preserving cube root for `f32` (`fast_cbrt(-8.0) ≈ -2.0`).
///
/// # Algorithm
///
/// A bit-level initial estimate (the IEEE-754 bit pattern divided by 3 plus a
/// bias constant) is refined by two Newton–Raphson steps of
/// `y ← (2y + x / y²) / 3`.
///
/// # Special cases
///
/// - `±0.0` returns `0.0`.
/// - `±∞` returns `±∞`, and NaN returns NaN. Without the explicit check, the
///   second Newton step evaluates `∞ / ∞` and yields NaN for infinite input.
/// - Subnormal inputs are first scaled by 2²⁴ into the normal range. The bias
///   constant only produces a usable seed for normal floats. The result is
///   then divided by 2⁸ = ∛(2²⁴). Both scalings are exact powers of two, and
///   normal inputs take the unscaled path, so their results are unaffected.
pub fn fast_cbrt(x: f32) -> f32 {
    // 2^24 lifts even the smallest subnormal (2^-149) to 2^-125, which is
    // normal. Its cube root, 2^8, undoes the scaling on the result.
    const SUBNORMAL_SCALE: f32 = 16_777_216.0;
    const SUBNORMAL_SCALE_CBRT: f32 = 256.0;

    if x == 0.0 {
        return 0.0;
    }
    if !x.is_finite() {
        // cbrt(±inf) = ±inf and cbrt(NaN) = NaN; pass them through unchanged.
        return x;
    }
    let sign = x.signum();
    let magnitude = x.abs();
    let (v, unscale) = if magnitude < f32::MIN_POSITIVE {
        (magnitude * SUBNORMAL_SCALE, SUBNORMAL_SCALE_CBRT)
    } else {
        (magnitude, 1.0)
    };
    // Treating the bit pattern as a scaled log2, dividing by 3 approximates
    // the cube root. The constant restores the exponent bias. For a positive
    // normal `v` the sum stays below u32::MAX.
    let mut y = f32::from_bits(v.to_bits() / 3 + 0x2a51_7d48);
    // Two Newton–Raphson steps on y^3 = v (quadratic convergence).
    y = (2.0 * y + v / (y * y)) / 3.0;
    y = (2.0 * y + v / (y * y)) / 3.0;
    sign * (y / unscale)
}
pub fn rgb_to_oklab(r: f32, g: f32, b: f32, m1: &GamutMatrix) -> [f32; 3] {
let l = m1[0][0] * r + m1[0][1] * g + m1[0][2] * b;
let m = m1[1][0] * r + m1[1][1] * g + m1[1][2] * b;
let s = m1[2][0] * r + m1[2][1] * g + m1[2][2] * b;
let l_ = fast_cbrt(l);
let m_ = fast_cbrt(m);
let s_ = fast_cbrt(s);
let ok_l = OKLAB_FROM_LMS_CBRT[0][0] * l_
+ OKLAB_FROM_LMS_CBRT[0][1] * m_
+ OKLAB_FROM_LMS_CBRT[0][2] * s_;
let ok_a = OKLAB_FROM_LMS_CBRT[1][0] * l_
+ OKLAB_FROM_LMS_CBRT[1][1] * m_
+ OKLAB_FROM_LMS_CBRT[1][2] * s_;
let ok_b = OKLAB_FROM_LMS_CBRT[2][0] * l_
+ OKLAB_FROM_LMS_CBRT[2][1] * m_
+ OKLAB_FROM_LMS_CBRT[2][2] * s_;
[ok_l, ok_a, ok_b]
}
/// Converts Oklab `[L, a, b]` back to RGB. This is the inverse of
/// [`rgb_to_oklab`].
///
/// `m1_inv` is a combined LMS → RGB matrix, such as the one returned by
/// [`lms_to_rgb_matrix`]. No transfer function is applied and nothing is
/// clamped, so out-of-gamut colors can produce components outside `[0, 1]`.
pub fn oklab_to_rgb(l: f32, a: f32, b: f32, m1_inv: &GamutMatrix) -> [f32; 3] {
    // Undo the Oklab projection to recover the cube-rooted cone responses.
    let [lc, mc, sc] = LMS_CBRT_FROM_OKLAB.map(|row| row[0] * l + row[1] * a + row[2] * b);
    // Undo the cube root.
    let [cl, cm, cs] = [lc, mc, sc].map(|v| v * v * v);
    let from_lms: GamutMatrix = *m1_inv;
    from_lms.map(|row| row[0] * cl + row[1] * cm + row[2] * cs)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `fast_cbrt` agrees with `f32::cbrt` to 1e-5 (absolute or relative)
    /// across several decades of non-negative input, including zero.
    #[test]
    fn fast_cbrt_accuracy() {
        let samples: [f32; 10] = [0.0, 0.001, 0.01, 0.1, 0.5, 1.0, 2.0, 8.0, 27.0, 100.0];
        for x in samples {
            let expected = x.cbrt();
            let got = fast_cbrt(x);
            let err = (got - expected).abs();
            let within_abs = err < 1e-5;
            let within_rel = err / expected.max(1e-10) < 1e-5;
            assert!(
                within_abs || within_rel,
                "fast_cbrt({x}) = {got}, expected {expected}, err = {err}"
            );
        }
    }

    /// Negative inputs keep their sign.
    #[test]
    fn fast_cbrt_negative() {
        let got = fast_cbrt(-8.0);
        assert!((got + 2.0).abs() < 1e-5, "fast_cbrt(-8) = {got}");
    }

    /// RGB → Oklab → RGB with the BT.709 matrices reproduces the input to 1e-4.
    #[test]
    fn oklab_roundtrip_bt709() {
        let to_lms = rgb_to_lms_matrix(ColorPrimaries::Bt709).unwrap();
        let from_lms = lms_to_rgb_matrix(ColorPrimaries::Bt709).unwrap();
        let colors: [[f32; 3]; 5] = [
            [0.5, 0.3, 0.8],
            [1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
            [0.0, 0.0, 1.0],
            [0.18, 0.18, 0.18],
        ];
        for rgb in &colors {
            let [l, a, b] = rgb_to_oklab(rgb[0], rgb[1], rgb[2], &to_lms);
            let back = oklab_to_rgb(l, a, b, &from_lms);
            for (c, (want, got)) in rgb.iter().zip(back).enumerate() {
                let err = (want - got).abs();
                assert!(err < 1e-4, "roundtrip error for {rgb:?} channel {c}: {err}");
            }
        }
    }

    /// BT.709 white maps to L ≈ 1 with neutral chroma.
    #[test]
    fn oklab_white_point() {
        let to_lms = rgb_to_lms_matrix(ColorPrimaries::Bt709).unwrap();
        let [l, a, b] = rgb_to_oklab(1.0, 1.0, 1.0, &to_lms);
        assert!((l - 1.0).abs() < 5e-4, "white L should be ~1.0, got {l}");
        assert!(a.abs() < 5e-4, "white a should be ~0.0, got {a}");
        assert!(b.abs() < 5e-4, "white b should be ~0.0, got {b}");
    }

    /// Black maps to the Oklab origin.
    #[test]
    fn oklab_black_point() {
        let to_lms = rgb_to_lms_matrix(ColorPrimaries::Bt709).unwrap();
        let [l, a, b] = rgb_to_oklab(0.0, 0.0, 0.0, &to_lms);
        assert!(l.abs() < 1e-6, "black L should be ~0.0, got {l}");
        assert!(a.abs() < 1e-6, "black a should be ~0.0, got {a}");
        assert!(b.abs() < 1e-6, "black b should be ~0.0, got {b}");
    }

    /// Round trip also holds for the wider BT.2020 primaries.
    #[test]
    fn oklab_roundtrip_bt2020() {
        let to_lms = rgb_to_lms_matrix(ColorPrimaries::Bt2020).unwrap();
        let from_lms = lms_to_rgb_matrix(ColorPrimaries::Bt2020).unwrap();
        let rgb: [f32; 3] = [0.4, 0.6, 0.2];
        let [l, a, b] = rgb_to_oklab(rgb[0], rgb[1], rgb[2], &to_lms);
        let back = oklab_to_rgb(l, a, b, &from_lms);
        for (c, (want, got)) in rgb.iter().zip(back).enumerate() {
            let err = (want - got).abs();
            assert!(err < 1e-4, "BT.2020 roundtrip error channel {c}: {err}");
        }
    }

    /// Known primaries yield matrices; `Unknown` yields `None`.
    #[test]
    fn combined_matrices_available() {
        assert!(rgb_to_lms_matrix(ColorPrimaries::Bt709).is_some());
        assert!(rgb_to_lms_matrix(ColorPrimaries::DisplayP3).is_some());
        assert!(rgb_to_lms_matrix(ColorPrimaries::Bt2020).is_some());
        assert!(rgb_to_lms_matrix(ColorPrimaries::Unknown).is_none());
        assert!(lms_to_rgb_matrix(ColorPrimaries::Bt709).is_some());
        assert!(lms_to_rgb_matrix(ColorPrimaries::Unknown).is_none());
    }
}