use alloc::vec::Vec;
use core::array;
use palette::{LinSrgb, Oklab, Srgb, cast};
use wide::{CmpEq as _, CmpGe as _, CmpLe as _, f32x8, u32x8};
/// Lane-wise cube root approximation for `f32x8`.
///
/// An initial guess is built directly from the IEEE-754 bit pattern: the magnitude
/// bits are divided by 3 and the constant `C` is added to re-bias the exponent
/// (`C = (BIAS - BIAS / 3 - 2 * MU / 3) * EXP_SHIFT`). The guess is then refined with
/// two Halley steps, `a <- a * (2x + a^3) / (x + 2a^3)`. Lanes equal to `0.0` are
/// returned unchanged, because the bit-level guess (and its refinement) is nonzero
/// for them.
///
/// Debug builds assert that every lane is either `0.0` or a normal value in
/// `[f32::MIN_POSITIVE, 1.0]`. The sign bit is carried through the initial guess,
/// but negative lanes are outside the asserted domain.
#[inline]
fn cbrt(x: f32x8) -> f32x8 {
// Offset of the "bits read as an integer ~ log2" approximation.
// NOTE(review): presumably the usual empirically tuned "mu" constant; confirm origin.
const MU: f64 = 0.049593534765;
// Exponent bias of f32: MAX_EXP (128) - 1 = 127.
const BIAS: f64 = (f32::MAX_EXP - 1) as f64;
// 2^23: one unit of the exponent field when the bit pattern is read as a u32.
const EXP_SHIFT: f64 = (1 << (f32::MANTISSA_DIGITS - 1)) as f64;
// Added after dividing the bit pattern by 3 so the exponent is re-biased.
#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
const C: u32 = ((BIAS - BIAS / 3.0 - MU * 2.0 / 3.0) * EXP_SHIFT) as u32;
debug_assert!((x.simd_eq(0.0) | (x.simd_ge(f32::MIN_POSITIVE) & x.simd_le(1.0))).all());
let x_bits: u32x8 = bytemuck::cast(x);
// Split off the sign bit; only the magnitude bits are divided.
let sign = x_bits & u32x8::splat(0x80000000);
let mut without_sign = x_bits & u32x8::splat(0x7fffffff);
// Integer division by 3 is done one lane at a time (the loop variable shadows `x`
// only inside this loop).
// NOTE(review): presumably because wide offers no u32x8 division; confirm.
for x in without_sign.as_mut_array() {
*x /= 3;
}
let without_sign = without_sign + u32x8::splat(C);
let x_bits = sign | without_sign;
// Initial cube-root estimate reinterpreted back as floats.
let a: f32x8 = bytemuck::cast(x_bits);
// Two Halley iterations: a <- a * (2x + a^3) / (x + 2a^3).
let a3 = a * a * a;
let a = a * (x + x + a3) / (x + a3 + a3);
let a3 = a * a * a;
let a = a * (x + x + a3) / (x + a3 + a3);
// For x == 0 the estimate above is nonzero (each Halley step only halves it),
// so those lanes take x itself.
x.simd_eq(0.0).blend(x, a)
}
/// Converts eight linear sRGB colours (one per SIMD lane) to Oklab.
///
/// The RGB components are mixed into LMS with a 3×3 matrix. Each LMS component is
/// cube-rooted with [`cbrt`], and a second 3×3 matrix maps the roots to `(L, a, b)`.
/// Every matrix row is applied as `c0 * x + c1 * y + c2 * z`, summed left to right.
#[allow(clippy::excessive_precision, clippy::inline_always)]
#[inline(always)]
fn lin_srgb_to_oklab(lin_srgb: LinSrgb<f32x8>) -> Oklab<f32x8> {
    // Linear sRGB -> LMS.
    const RGB_TO_LMS: [[f32; 3]; 3] = [
        [0.4122214708, 0.5363325363, 0.0514459929],
        [0.2119034982, 0.6806995451, 0.1073969566],
        [0.0883024619, 0.2817188376, 0.6299787005],
    ];
    // Cube-rooted LMS -> Oklab (L, a, b).
    const LMS_TO_LAB: [[f32; 3]; 3] = [
        [0.2104542553, 0.7936177850, -0.0040720468],
        [1.9779984951, -2.4285922050, 0.4505937099],
        [0.0259040371, 0.7827717662, -0.8086757660],
    ];

    let mix =
        |row: [f32; 3], x: f32x8, y: f32x8, z: f32x8| row[0] * x + row[1] * y + row[2] * z;

    let (red, green, blue) = lin_srgb.into_components();
    let l_root = cbrt(mix(RGB_TO_LMS[0], red, green, blue));
    let m_root = cbrt(mix(RGB_TO_LMS[1], red, green, blue));
    let s_root = cbrt(mix(RGB_TO_LMS[2], red, green, blue));

    Oklab {
        l: mix(LMS_TO_LAB[0], l_root, m_root, s_root),
        a: mix(LMS_TO_LAB[1], l_root, m_root, s_root),
        b: mix(LMS_TO_LAB[2], l_root, m_root, s_root),
    }
}
#[allow(clippy::inline_always)]
#[inline(always)]
fn lin_srgb_to_oklab_arr(lin_srgb: LinSrgb<f32x8>) -> [Oklab; 8] {
let Oklab { l, a, b } = lin_srgb_to_oklab(lin_srgb);
array::from_fn(|i| Oklab {
l: l.as_array()[i],
a: a.as_array()[i],
b: b.as_array()[i],
})
}
pub fn srgb8_to_oklab(input: &[Srgb<u8>]) -> Vec<Oklab> {
let mut output = bytemuck::zeroed_vec(input.len());
let (out_chunks, out_remainder) = output.as_chunks_mut::<8>();
let (in_chunks, in_remainder) = input.as_chunks::<8>();
for (chunk, output) in in_chunks.iter().zip(out_chunks) {
let mut arr = [[0.0; 8]; 3];
for (i, srgb) in chunk.iter().enumerate() {
let lin_srgb = cast::into_array(srgb.into_linear());
for (arr, c) in arr.iter_mut().zip(lin_srgb) {
arr[i] = c;
}
}
*output = lin_srgb_to_oklab_arr(cast::from_array(arr.map(f32x8::new)));
}
if !in_remainder.is_empty() {
let mut arr = [[0.0; 8]; 3];
for (i, srgb) in in_remainder.iter().enumerate() {
let lin_srgb = cast::into_array(srgb.into_linear());
for (arr, c) in arr.iter_mut().zip(lin_srgb) {
arr[i] = c;
}
}
let oklab = lin_srgb_to_oklab_arr(cast::from_array(arr.map(f32x8::new)));
out_remainder.copy_from_slice(&oklab[..in_remainder.len()]);
}
output
}
/// Converts eight Oklab colours (one per SIMD lane) back to linear sRGB.
///
/// A 3×3 matrix maps `(L, a, b)` to the cube roots of LMS, each value is cubed, and
/// a second 3×3 matrix yields linear RGB. Nothing here clamps the result, so the
/// returned components are not restricted to `[0, 1]`.
#[allow(clippy::excessive_precision, clippy::inline_always)]
#[inline(always)]
fn oklab_to_lin_srgb(oklab: Oklab<f32x8>) -> LinSrgb<f32x8> {
    let cube = |v: f32x8| v * v * v;
    let Oklab {
        l: lightness,
        a: axis_a,
        b: axis_b,
    } = oklab;

    // (L, a, b) -> cube roots of LMS, then undo the cube root.
    let l = cube(lightness + 0.3963377774 * axis_a + 0.2158037573 * axis_b);
    let m = cube(lightness - 0.1055613458 * axis_a - 0.0638541728 * axis_b);
    let s = cube(lightness - 0.0894841775 * axis_a - 1.2914855480 * axis_b);

    // LMS -> linear sRGB.
    LinSrgb::new(
        4.0767416621 * l - 3.3077115913 * m + 0.2309699292 * s,
        -1.2684380046 * l + 2.6097574011 * m - 0.3413193965 * s,
        -0.0041960863 * l - 0.7034186147 * m + 1.7076147010 * s,
    )
}
#[allow(clippy::excessive_precision, clippy::inline_always)]
#[inline(always)] fn oklab_arr_to_lin_srgb(oklab: [Oklab; 8]) -> LinSrgb<f32x8> {
let oklab = cast::into_array_array(oklab);
let oklab: [[f32; 8]; 3] = core::array::from_fn(|i| oklab.map(|x| x[i]));
let [l, a, b] = oklab.map(f32x8::new);
oklab_to_lin_srgb(Oklab { l, a, b })
}
pub fn oklab_to_srgb8(input: &[Oklab]) -> Vec<Srgb<u8>> {
let mut output = bytemuck::zeroed_vec(input.len());
let (out_chunks, out_remainder) = output.as_chunks_mut::<8>();
let (in_chunks, in_remainder) = input.as_chunks::<8>();
for (&chunk, output) in in_chunks.iter().zip(out_chunks) {
let lin_srgb = oklab_arr_to_lin_srgb(chunk);
for (i, output) in output.iter_mut().enumerate() {
*output = LinSrgb::new(
lin_srgb.red.as_array()[i],
lin_srgb.green.as_array()[i],
lin_srgb.blue.as_array()[i],
)
.into_encoding();
}
}
if !in_remainder.is_empty() {
let mut oklab = [Oklab::new(0.0, 0.0, 0.0); 8];
oklab[..in_remainder.len()].copy_from_slice(in_remainder);
let lin_srgb = oklab_arr_to_lin_srgb(oklab);
for (i, output) in out_remainder.iter_mut().enumerate() {
*output = LinSrgb::new(
lin_srgb.red.as_array()[i],
lin_srgb.green.as_array()[i],
lin_srgb.blue.as_array()[i],
)
.into_encoding();
}
}
output
}
#[cfg(feature = "threads")]
mod parallel {
    //! Rayon-parallel variants of the batch conversions. Each full 8-pixel chunk is
    //! converted independently with the same SIMD kernels as the serial functions.

    use super::{lin_srgb_to_oklab_arr, oklab_arr_to_lin_srgb};
    // Explicit import so this module does not rely on a std prelude (the parent
    // module imports `Vec` from `alloc` the same way).
    use alloc::vec::Vec;
    use core::array;
    use palette::{LinSrgb, Oklab, Srgb, cast};
    use rayon::prelude::*;
    use wide::f32x8;

    /// Packs up to eight 8-bit sRGB pixels into linear sRGB, lane `i` = pixel `i`.
    /// Lanes at or beyond `pixels.len()` are left at `0.0`.
    fn srgb8_to_lin_srgb_x8(pixels: &[Srgb<u8>]) -> LinSrgb<f32x8> {
        debug_assert!(pixels.len() <= 8, "at most 8 pixels fit in one f32x8 batch");
        let mut channels = [[0.0; 8]; 3];
        for (lane, srgb) in pixels.iter().enumerate() {
            let lin_srgb = cast::into_array(srgb.into_linear());
            for (channel, value) in channels.iter_mut().zip(lin_srgb) {
                channel[lane] = value;
            }
        }
        cast::from_array(channels.map(f32x8::new))
    }

    /// Encodes all eight lanes of `lin_srgb` as 8-bit sRGB, lane `i` -> element `i`.
    fn lin_srgb_x8_to_srgb8_arr(lin_srgb: LinSrgb<f32x8>) -> [Srgb<u8>; 8] {
        let (red, green, blue) = (
            lin_srgb.red.as_array(),
            lin_srgb.green.as_array(),
            lin_srgb.blue.as_array(),
        );
        array::from_fn(|lane| LinSrgb::new(red[lane], green[lane], blue[lane]).into_encoding())
    }

    /// Parallel version of [`super::srgb8_to_oklab`].
    ///
    /// Full 8-pixel chunks are converted on the rayon thread pool. A trailing partial
    /// chunk is zero-padded and converted on the calling thread, and only its valid
    /// lanes are appended.
    pub fn srgb8_to_oklab_par(input: &[Srgb<u8>]) -> Vec<Oklab> {
        // One slot per chunk, counting a trailing partial chunk. `collect_into_vec`
        // reuses this buffer, so after flattening there is presumably already room for
        // the remainder.
        let mut output = Vec::<[Oklab; 8]>::with_capacity(input.len().div_ceil(8));
        let (chunks, remainder) = input.as_chunks::<8>();
        chunks
            .par_iter()
            .map(|chunk| lin_srgb_to_oklab_arr(srgb8_to_lin_srgb_x8(chunk)))
            .collect_into_vec(&mut output);
        let mut output = output.into_flattened();
        if !remainder.is_empty() {
            let oklab = lin_srgb_to_oklab_arr(srgb8_to_lin_srgb_x8(remainder));
            output.extend_from_slice(&oklab[..remainder.len()]);
        }
        output
    }

    /// Parallel version of [`super::oklab_to_srgb8`].
    ///
    /// Full 8-colour chunks are converted on the rayon thread pool, at least two
    /// chunks per job. A trailing partial chunk is padded with
    /// `Oklab::new(0.0, 0.0, 0.0)` and converted on the calling thread.
    pub fn oklab_to_srgb8_par(input: &[Oklab]) -> Vec<Srgb<u8>> {
        let mut output = Vec::<[Srgb<u8>; 8]>::with_capacity(input.len().div_ceil(8));
        let (chunks, remainder) = input.as_chunks::<8>();
        chunks
            .par_iter()
            // NOTE(review): only this direction sets a minimum job length; presumably a
            // tuning choice - confirm.
            .with_min_len(2)
            .map(|&chunk| lin_srgb_x8_to_srgb8_arr(oklab_arr_to_lin_srgb(chunk)))
            .collect_into_vec(&mut output);
        let mut output = output.into_flattened();
        if !remainder.is_empty() {
            let mut padded = [Oklab::new(0.0, 0.0, 0.0); 8];
            padded[..remainder.len()].copy_from_slice(remainder);
            let srgb = lin_srgb_x8_to_srgb8_arr(oklab_arr_to_lin_srgb(padded));
            output.extend_from_slice(&srgb[..remainder.len()]);
        }
        output
    }
}
// Expose the rayon-based conversions from this module when `threads` is enabled.
#[cfg(feature = "threads")]
pub use parallel::{oklab_to_srgb8_par, srgb8_to_oklab_par};
#[cfg(test)]
mod tests {
    use super::*;
    use core::array;
    use palette::{FromColor as _, LinSrgb, Oklab, Srgb, cast};
    use wide::f32x8;

    /// Small deterministic input: 17 pixels = two full 8-pixel chunks plus one extra,
    /// so prefixes of it cover empty input, partial chunks and chunk + remainder.
    #[allow(clippy::cast_possible_truncation)]
    fn sample_pixels() -> Vec<Srgb<u8>> {
        (0..17u32)
            .map(|i| Srgb::new((i * 15) as u8, (255 - i * 13) as u8, (i * 7 + 3) as u8))
            .collect()
    }

    /// Bit patterns of every Oklab component, for exact comparisons.
    fn oklab_bits(colors: &[Oklab]) -> Vec<[u32; 3]> {
        colors
            .iter()
            .map(|&c| cast::into_array(c).map(f32::to_bits))
            .collect()
    }

    /// The remainder path must produce exactly what the full-chunk path produces for
    /// the same pixels (the existing roundtrip test only uses multiples of 8).
    #[test]
    fn partial_chunks_match_full_chunks() {
        let pixels = sample_pixels();
        let full_oklab = srgb8_to_oklab(&pixels);
        let full_srgb = oklab_to_srgb8(&full_oklab);
        assert_eq!(full_oklab.len(), pixels.len());
        assert_eq!(full_srgb.len(), pixels.len());
        for len in 0..=pixels.len() {
            let oklab = srgb8_to_oklab(&pixels[..len]);
            assert_eq!(oklab_bits(&oklab), oklab_bits(&full_oklab[..len]), "len = {len}");
            let srgb = oklab_to_srgb8(&full_oklab[..len]);
            assert_eq!(srgb, &full_srgb[..len], "len = {len}");
        }
    }

    #[test]
    fn sample_roundtrip() {
        let pixels = sample_pixels();
        assert_eq!(oklab_to_srgb8(&srgb8_to_oklab(&pixels)), pixels);
    }

    #[cfg(feature = "threads")]
    #[test]
    fn parallel_matches_serial() {
        let pixels = sample_pixels();
        for len in 0..=pixels.len() {
            let input = &pixels[..len];
            let oklab = srgb8_to_oklab(input);
            assert_eq!(
                oklab_bits(&srgb8_to_oklab_par(input)),
                oklab_bits(&oklab),
                "len = {len}"
            );
            assert_eq!(oklab_to_srgb8_par(&oklab), oklab_to_srgb8(&oklab), "len = {len}");
        }
    }

    #[test]
    #[ignore = "takes a long time"]
    fn lin_srgb_to_oklab_oracle() {
        for r in 0..=u8::MAX {
            for g in 0..=u8::MAX {
                for b in (0..=u8::MAX).step_by(8) {
                    #[allow(clippy::cast_possible_truncation)]
                    let lin_srgb: [_; 8] =
                        array::from_fn(|i| Srgb::new(r, g, b + i as u8).into_linear());
                    let actual = lin_srgb_to_oklab_arr(LinSrgb::new(
                        f32x8::new(lin_srgb.map(|x| x.red)),
                        f32x8::new(lin_srgb.map(|x| x.green)),
                        f32x8::new(lin_srgb.map(|x| x.blue)),
                    ));
                    let expected = lin_srgb.map(Oklab::from_color);
                    for (actual, expected) in actual.into_iter().zip(expected) {
                        let [l, a, b] = cast::into_array(actual - expected).map(f32::abs);
                        assert!(l < 0.000001, "difference in l = {l} with ({r}, {g}, {b})");
                        assert!(a < 0.000001, "difference in a = {a} with ({r}, {g}, {b})");
                        assert!(b < 0.000001, "difference in b = {b} with ({r}, {g}, {b})");
                    }
                }
            }
        }
    }

    #[test]
    #[ignore = "takes a long time"]
    fn srgb8_to_oklab_roundtrip() {
        #[allow(clippy::cast_possible_truncation)]
        let mut srgb: [_; 256 * 256] =
            array::from_fn(|i| Srgb::new(0, (i / 256) as u8, (i % 256) as u8));
        let srgb = &mut srgb;
        for r in 0..=u8::MAX {
            for srgb in &mut *srgb {
                srgb.red = r;
            }
            let oklab = srgb8_to_oklab(srgb);
            let actual = oklab_to_srgb8(&oklab);
            for (actual, expected) in actual.iter().zip(&*srgb) {
                assert_eq!(actual, expected);
            }
        }
    }
}