use crate::{
Color,
generated::{COLORS, LABS_A, LABS_B, LABS_C, LABS_L},
};
// Backend modules. Every module declared here provides a `nearest_idx` function
// that the dispatchers and tests in this file call.

// Portable implementation; `nearest_idx` falls back to it when no SIMD path is taken.
pub(crate) mod scalar;
// CIEDE2000 search, used by `nearest_ciede2000` when the `lut` feature is off.
pub(crate) mod ciede2000;
// CIEDE2000 search used by `nearest_ciede2000` when the `lut` feature is on;
// its `nearest_idx` takes the RGB triple in addition to the Lab query.
#[cfg(feature = "lut")]
pub(crate) mod ciede2000_lut;
// Portable CIE94 implementation; `nearest_cie94` falls back to it.
pub(crate) mod cie94;
// CIE94 SIMD kernels, one per target. The x86_64 modules are compiled on every
// x86_64 build, but `nearest_cie94` only reaches them after runtime CPU detection,
// which is itself gated on the `std` feature.
#[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
pub(crate) mod cie94_aarch64_neon;
#[cfg(target_arch = "x86_64")]
pub(crate) mod cie94_x86_sse41;
#[cfg(target_arch = "x86_64")]
pub(crate) mod cie94_x86_avx2;
#[cfg(target_arch = "x86_64")]
pub(crate) mod cie94_x86_avx512;
#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
pub(crate) mod cie94_wasm_simd128;
// SIMD kernels behind `nearest_idx`, gated the same way as the CIE94 kernels above.
#[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
pub(crate) mod aarch64_neon;
#[cfg(target_arch = "x86_64")]
pub(crate) mod x86_sse41;
#[cfg(target_arch = "x86_64")]
pub(crate) mod x86_avx2;
#[cfg(target_arch = "x86_64")]
pub(crate) mod x86_avx512;
#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
pub(crate) mod wasm_simd128;
/// Returns the index into `COLORS` of the palette entry nearest to `query`.
///
/// `query` is a Lab triple; callers in this file obtain it from `crate::rgb_to_lab`.
///
/// Backend selection, first match wins:
/// 1. aarch64 with NEON (chosen at compile time);
/// 2. wasm32 with simd128 (chosen at compile time);
/// 3. x86_64 with the `std` feature: AVX-512F, then AVX2, then SSE4.1, chosen by
///    runtime CPU detection. The `colorthief_disable_avx512` and
///    `colorthief_disable_avx2` cfgs skip the respective tier;
/// 4. the scalar implementation.
///
/// The `colorthief_force_scalar` cfg compiles out every SIMD path.
// `unsafe_code`: the x86_64 kernels are invoked through `unsafe` blocks.
// `unreachable_code`: where a compile-time SIMD branch returns unconditionally,
// the scalar tail below can never run.
#[allow(unsafe_code, unreachable_code)]
#[inline]
pub(crate) fn nearest_idx(query: [f32; 3]) -> usize {
    #[cfg(all(
        target_arch = "aarch64",
        target_feature = "neon",
        not(colorthief_force_scalar)
    ))]
    {
        return aarch64_neon::nearest_idx(query);
    }

    #[cfg(all(
        target_arch = "wasm32",
        target_feature = "simd128",
        not(colorthief_force_scalar)
    ))]
    {
        return wasm_simd128::nearest_idx(query);
    }

    #[cfg(all(target_arch = "x86_64", feature = "std", not(colorthief_force_scalar)))]
    {
        // Compile-time opt-outs for the two widest tiers; both fold to constants.
        let avx512_allowed = !cfg!(colorthief_disable_avx512);
        let avx2_allowed = !cfg!(colorthief_disable_avx2);

        // Widest vectors first; each feature probe only runs if the tiers above it declined.
        if avx512_allowed && std::is_x86_feature_detected!("avx512f") {
            // SAFETY: the runtime check on the line above confirmed AVX-512F.
            // NOTE(review): assumes the kernel needs no feature beyond avx512f — confirm.
            return unsafe { x86_avx512::nearest_idx(query) };
        }
        if avx2_allowed && std::is_x86_feature_detected!("avx2") {
            // SAFETY: the runtime check on the line above confirmed AVX2.
            return unsafe { x86_avx2::nearest_idx(query) };
        }
        if std::is_x86_feature_detected!("sse4.1") {
            // SAFETY: the runtime check on the line above confirmed SSE4.1.
            return unsafe { x86_sse41::nearest_idx(query) };
        }
    }

    scalar::nearest_idx(query)
}
/// Returns the palette color nearest to `query`, using whichever backend
/// `nearest_idx` selects.
#[inline]
pub(crate) fn nearest(query: [f32; 3]) -> &'static Color {
    let idx = nearest_idx(query);
    COLORS[idx]
}
/// Returns the palette color nearest to `rgb` according to the `ciede2000_lut`
/// backend (build with the `lut` feature enabled).
///
/// The backend receives both the original RGB triple and its Lab conversion.
#[cfg(feature = "lut")]
#[inline]
pub(crate) fn nearest_ciede2000(rgb: [u8; 3]) -> &'static Color {
    let lab = crate::rgb_to_lab(rgb);
    let idx = ciede2000_lut::nearest_idx(rgb, lab);
    COLORS[idx]
}
/// Returns the palette color nearest to `rgb` according to the `ciede2000`
/// backend (build without the `lut` feature).
///
/// The RGB triple is converted to Lab first; the backend only sees the Lab query.
#[cfg(not(feature = "lut"))]
#[inline]
pub(crate) fn nearest_ciede2000(rgb: [u8; 3]) -> &'static Color {
    let lab = crate::rgb_to_lab(rgb);
    let idx = ciede2000::nearest_idx(lab);
    COLORS[idx]
}
/// Returns the palette color nearest to `query` according to the CIE94 backends.
///
/// `query` is a Lab triple; the tests in this file obtain it from `crate::rgb_to_lab`.
///
/// Backend selection mirrors `nearest_idx`, first match wins:
/// 1. aarch64 with NEON (chosen at compile time);
/// 2. wasm32 with simd128 (chosen at compile time);
/// 3. x86_64 with the `std` feature: AVX-512F, then AVX2, then SSE4.1, chosen by
///    runtime CPU detection. The `colorthief_disable_avx512` and
///    `colorthief_disable_avx2` cfgs skip the respective tier;
/// 4. the portable `cie94` implementation.
///
/// The `colorthief_force_scalar` cfg compiles out every SIMD path.
// `unsafe_code`: the x86_64 kernels are invoked through `unsafe` blocks.
// `unreachable_code`: where a compile-time SIMD branch returns unconditionally,
// the portable tail below can never run.
#[allow(unsafe_code, unreachable_code)]
#[inline]
pub(crate) fn nearest_cie94(query: [f32; 3]) -> &'static Color {
    #[cfg(all(
        target_arch = "aarch64",
        target_feature = "neon",
        not(colorthief_force_scalar)
    ))]
    {
        let idx = cie94_aarch64_neon::nearest_idx(query);
        return COLORS[idx];
    }

    #[cfg(all(
        target_arch = "wasm32",
        target_feature = "simd128",
        not(colorthief_force_scalar)
    ))]
    {
        let idx = cie94_wasm_simd128::nearest_idx(query);
        return COLORS[idx];
    }

    #[cfg(all(target_arch = "x86_64", feature = "std", not(colorthief_force_scalar)))]
    {
        // Compile-time opt-outs for the two widest tiers; both fold to constants.
        let avx512_allowed = !cfg!(colorthief_disable_avx512);
        let avx2_allowed = !cfg!(colorthief_disable_avx2);

        // Widest vectors first; each feature probe only runs if the tiers above it declined.
        if avx512_allowed && std::is_x86_feature_detected!("avx512f") {
            // SAFETY: the runtime check on the line above confirmed AVX-512F.
            // NOTE(review): assumes the kernel needs no feature beyond avx512f — confirm.
            let idx = unsafe { cie94_x86_avx512::nearest_idx(query) };
            return COLORS[idx];
        }
        if avx2_allowed && std::is_x86_feature_detected!("avx2") {
            // SAFETY: the runtime check on the line above confirmed AVX2.
            let idx = unsafe { cie94_x86_avx2::nearest_idx(query) };
            return COLORS[idx];
        }
        if std::is_x86_feature_detected!("sse4.1") {
            // SAFETY: the runtime check on the line above confirmed SSE4.1.
            let idx = unsafe { cie94_x86_sse41::nearest_idx(query) };
            return COLORS[idx];
        }
    }

    let idx = cie94::nearest_idx(query);
    COLORS[idx]
}
#[cfg(test)]
#[allow(unsafe_code)]
mod tests {
use super::*;
#[cfg(feature = "std")]
#[allow(dead_code)]
/// Yields the sRGB sample grid shared by the scalar-vs-SIMD parity tests.
///
/// Each channel takes 17 values: 0, 16, …, 240 and the 255 endpoint, giving
/// 17³ = 4913 triples. Iteration order is red outermost, blue innermost.
///
/// Stepping `0..256` by 16 stops at 240, which gives only 16 values per channel
/// (4096 triples) and never exercises a fully saturated channel; the final sample
/// is therefore clamped to 255 instead of being dropped.
fn parity_grid() -> impl Iterator<Item = [u8; 3]> {
    /// The 17 per-channel sample values, in ascending order.
    fn channel_samples() -> impl Iterator<Item = u8> {
        // Step 16 would be 256; clamp it to 255 so the top of the range is covered.
        // After the clamp the value always fits in a `u8`, so the cast is lossless.
        (0..=16u32).map(|step| (step * 16).min(255) as u8)
    }

    channel_samples().flat_map(|r| {
        channel_samples().flat_map(move |g| channel_samples().map(move |b| [r, g, b]))
    })
}
/// Checks that the per-channel arrays `LABS_L`, `LABS_A` and `LABS_B` have one
/// entry per `COLORS` element and that each entry equals the matching component
/// of that color's `lab()` value.
#[test]
fn soa_lab_arrays_align_with_aos_colors() {
    let palette_len = COLORS.len();
    assert_eq!(LABS_L.len(), palette_len);
    assert_eq!(LABS_A.len(), palette_len);
    assert_eq!(LABS_B.len(), palette_len);

    // Exact equality is intended: the arrays are expected to hold the very same values.
    for (i, color) in COLORS.iter().enumerate() {
        let expected = color.lab();
        assert_eq!(LABS_L[i], expected[0], "L mismatch at index {i}");
        assert_eq!(LABS_A[i], expected[1], "a mismatch at index {i}");
        assert_eq!(LABS_B[i], expected[2], "b mismatch at index {i}");
    }
}
/// The NEON kernel must pick the same palette index as the scalar
/// implementation for every query in `parity_grid`.
#[test]
#[cfg_attr(miri, ignore = "4913-query × 949-entry grid is too slow under miri")]
#[cfg(all(target_arch = "aarch64", target_feature = "neon", feature = "std"))]
fn neon_and_scalar_agree_across_grid() {
    // Keep only the queries on which the backends disagree,
    // as (rgb, scalar index, NEON index).
    let mismatches: Vec<_> = parity_grid()
        .map(|rgb| {
            let lab = crate::rgb_to_lab(rgb);
            (rgb, scalar::nearest_idx(lab), aarch64_neon::nearest_idx(lab))
        })
        .filter(|(_, expected, actual)| expected != actual)
        .collect();

    // At most five disagreements are shown to keep the failure message short.
    let preview = &mismatches[..mismatches.len().min(5)];
    assert!(
        mismatches.is_empty(),
        "{} scalar/NEON mismatches across the 17³ grid; first few: {:?}",
        mismatches.len(),
        preview
    );
}
/// The SSE4.1 kernel must pick the same palette index as the scalar
/// implementation for every query in `parity_grid`. Skipped (with a note on
/// stderr) when the host CPU lacks SSE4.1.
#[test]
#[cfg_attr(miri, ignore = "4913-query × 949-entry grid is too slow under miri")]
#[cfg(all(target_arch = "x86_64", feature = "std"))]
fn sse41_and_scalar_agree_across_grid() {
    let host_has_sse41 = std::is_x86_feature_detected!("sse4.1");
    if !host_has_sse41 {
        eprintln!("skipping: SSE4.1 not detected on this host");
        return;
    }

    // Keep only the queries on which the backends disagree,
    // as (rgb, scalar index, SSE4.1 index).
    let mismatches: Vec<_> = parity_grid()
        .map(|rgb| {
            let lab = crate::rgb_to_lab(rgb);
            let expected = scalar::nearest_idx(lab);
            // SAFETY: SSE4.1 support was confirmed by the runtime check at the top of this test.
            let actual = unsafe { x86_sse41::nearest_idx(lab) };
            (rgb, expected, actual)
        })
        .filter(|(_, expected, actual)| expected != actual)
        .collect();

    // At most five disagreements are shown to keep the failure message short.
    let preview = &mismatches[..mismatches.len().min(5)];
    assert!(
        mismatches.is_empty(),
        "{} scalar/SSE4.1 mismatches; first few: {:?}",
        mismatches.len(),
        preview
    );
}
/// The AVX-512F kernel must pick the same palette index as the scalar
/// implementation for every query in `parity_grid`. Skipped (with a note on
/// stderr) when the host CPU lacks AVX-512F.
#[test]
#[cfg_attr(miri, ignore = "4913-query × 949-entry grid is too slow under miri")]
#[cfg(all(target_arch = "x86_64", feature = "std"))]
fn avx512_and_scalar_agree_across_grid() {
    let host_has_avx512f = std::is_x86_feature_detected!("avx512f");
    if !host_has_avx512f {
        eprintln!("skipping: AVX-512F not detected on this host");
        return;
    }

    // Keep only the queries on which the backends disagree,
    // as (rgb, scalar index, AVX-512F index).
    let mismatches: Vec<_> = parity_grid()
        .map(|rgb| {
            let lab = crate::rgb_to_lab(rgb);
            let expected = scalar::nearest_idx(lab);
            // SAFETY: AVX-512F support was confirmed by the runtime check at the top of this test.
            let actual = unsafe { x86_avx512::nearest_idx(lab) };
            (rgb, expected, actual)
        })
        .filter(|(_, expected, actual)| expected != actual)
        .collect();

    // At most five disagreements are shown to keep the failure message short.
    let preview = &mismatches[..mismatches.len().min(5)];
    assert!(
        mismatches.is_empty(),
        "{} scalar/AVX-512F mismatches; first few: {:?}",
        mismatches.len(),
        preview
    );
}
/// The CIE94 AVX-512F kernel must pick the same palette index as the portable
/// `cie94` implementation for every query in `parity_grid`. Skipped (with a
/// note on stderr) when the host CPU lacks AVX-512F.
#[test]
#[cfg_attr(miri, ignore = "4913-query × 949-entry grid is too slow under miri")]
#[cfg(all(target_arch = "x86_64", feature = "std"))]
fn cie94_avx512_and_scalar_agree_across_grid() {
    let host_has_avx512f = std::is_x86_feature_detected!("avx512f");
    if !host_has_avx512f {
        eprintln!("skipping: AVX-512F not detected on this host");
        return;
    }

    // Keep only the queries on which the backends disagree,
    // as (rgb, portable index, AVX-512F index).
    let mismatches: Vec<_> = parity_grid()
        .map(|rgb| {
            let lab = crate::rgb_to_lab(rgb);
            let expected = cie94::nearest_idx(lab);
            // SAFETY: AVX-512F support was confirmed by the runtime check at the top of this test.
            let actual = unsafe { cie94_x86_avx512::nearest_idx(lab) };
            (rgb, expected, actual)
        })
        .filter(|(_, expected, actual)| expected != actual)
        .collect();

    // At most five disagreements are shown to keep the failure message short.
    let preview = &mismatches[..mismatches.len().min(5)];
    assert!(
        mismatches.is_empty(),
        "{} CIE94 scalar/AVX-512F mismatches; first few: {:?}",
        mismatches.len(),
        preview
    );
}
/// The AVX2 kernel must pick the same palette index as the scalar
/// implementation for every query in `parity_grid`. Skipped (with a note on
/// stderr) when the host CPU lacks AVX2.
#[test]
#[cfg_attr(miri, ignore = "4913-query × 949-entry grid is too slow under miri")]
#[cfg(all(target_arch = "x86_64", feature = "std"))]
fn avx2_and_scalar_agree_across_grid() {
    let host_has_avx2 = std::is_x86_feature_detected!("avx2");
    if !host_has_avx2 {
        eprintln!("skipping: AVX2 not detected on this host");
        return;
    }

    // Keep only the queries on which the backends disagree,
    // as (rgb, scalar index, AVX2 index).
    let mismatches: Vec<_> = parity_grid()
        .map(|rgb| {
            let lab = crate::rgb_to_lab(rgb);
            let expected = scalar::nearest_idx(lab);
            // SAFETY: AVX2 support was confirmed by the runtime check at the top of this test.
            let actual = unsafe { x86_avx2::nearest_idx(lab) };
            (rgb, expected, actual)
        })
        .filter(|(_, expected, actual)| expected != actual)
        .collect();

    // At most five disagreements are shown to keep the failure message short.
    let preview = &mismatches[..mismatches.len().min(5)];
    assert!(
        mismatches.is_empty(),
        "{} scalar/AVX2 mismatches; first few: {:?}",
        mismatches.len(),
        preview
    );
}
/// The CIE94 NEON kernel must pick the same palette index as the portable
/// `cie94` implementation for every query in `parity_grid`.
#[test]
#[cfg_attr(miri, ignore = "4913-query × 949-entry grid is too slow under miri")]
#[cfg(all(target_arch = "aarch64", target_feature = "neon", feature = "std"))]
fn cie94_neon_and_scalar_agree_across_grid() {
    // Keep only the queries on which the backends disagree,
    // as (rgb, portable index, NEON index).
    let mismatches: Vec<_> = parity_grid()
        .map(|rgb| {
            let lab = crate::rgb_to_lab(rgb);
            (
                rgb,
                cie94::nearest_idx(lab),
                cie94_aarch64_neon::nearest_idx(lab),
            )
        })
        .filter(|(_, expected, actual)| expected != actual)
        .collect();

    // At most five disagreements are shown to keep the failure message short.
    let preview = &mismatches[..mismatches.len().min(5)];
    assert!(
        mismatches.is_empty(),
        "{} CIE94 scalar/NEON mismatches; first few: {:?}",
        mismatches.len(),
        preview
    );
}
/// The CIE94 SSE4.1 kernel must pick the same palette index as the portable
/// `cie94` implementation for every query in `parity_grid`. Skipped (with a
/// note on stderr) when the host CPU lacks SSE4.1.
#[test]
#[cfg_attr(miri, ignore = "4913-query × 949-entry grid is too slow under miri")]
#[cfg(all(target_arch = "x86_64", feature = "std"))]
fn cie94_sse41_and_scalar_agree_across_grid() {
    let host_has_sse41 = std::is_x86_feature_detected!("sse4.1");
    if !host_has_sse41 {
        eprintln!("skipping: SSE4.1 not detected");
        return;
    }

    // Keep only the queries on which the backends disagree,
    // as (rgb, portable index, SSE4.1 index).
    let mismatches: Vec<_> = parity_grid()
        .map(|rgb| {
            let lab = crate::rgb_to_lab(rgb);
            let expected = cie94::nearest_idx(lab);
            // SAFETY: SSE4.1 support was confirmed by the runtime check at the top of this test.
            let actual = unsafe { cie94_x86_sse41::nearest_idx(lab) };
            (rgb, expected, actual)
        })
        .filter(|(_, expected, actual)| expected != actual)
        .collect();

    // At most five disagreements are shown to keep the failure message short.
    let preview = &mismatches[..mismatches.len().min(5)];
    assert!(
        mismatches.is_empty(),
        "{} CIE94 scalar/SSE4.1 mismatches; first few: {:?}",
        mismatches.len(),
        preview
    );
}
/// The CIE94 AVX2 kernel must pick the same palette index as the portable
/// `cie94` implementation for every query in `parity_grid`. Skipped (with a
/// note on stderr) when the host CPU lacks AVX2.
#[test]
#[cfg_attr(miri, ignore = "4913-query × 949-entry grid is too slow under miri")]
#[cfg(all(target_arch = "x86_64", feature = "std"))]
fn cie94_avx2_and_scalar_agree_across_grid() {
    let host_has_avx2 = std::is_x86_feature_detected!("avx2");
    if !host_has_avx2 {
        eprintln!("skipping: AVX2 not detected");
        return;
    }

    // Keep only the queries on which the backends disagree,
    // as (rgb, portable index, AVX2 index).
    let mismatches: Vec<_> = parity_grid()
        .map(|rgb| {
            let lab = crate::rgb_to_lab(rgb);
            let expected = cie94::nearest_idx(lab);
            // SAFETY: AVX2 support was confirmed by the runtime check at the top of this test.
            let actual = unsafe { cie94_x86_avx2::nearest_idx(lab) };
            (rgb, expected, actual)
        })
        .filter(|(_, expected, actual)| expected != actual)
        .collect();

    // At most five disagreements are shown to keep the failure message short.
    let preview = &mismatches[..mismatches.len().min(5)];
    assert!(
        mismatches.is_empty(),
        "{} CIE94 scalar/AVX2 mismatches; first few: {:?}",
        mismatches.len(),
        preview
    );
}
/// The CIE94 WASM simd128 kernel must pick the same palette index as the
/// portable `cie94` implementation for every query in `parity_grid`.
#[test]
#[cfg_attr(miri, ignore = "4913-query × 949-entry grid is too slow under miri")]
#[cfg(all(target_arch = "wasm32", target_feature = "simd128", feature = "std"))]
fn cie94_wasm_simd128_and_scalar_agree_across_grid() {
    // Keep only the queries on which the backends disagree,
    // as (rgb, portable index, simd128 index).
    let mismatches: Vec<_> = parity_grid()
        .map(|rgb| {
            let lab = crate::rgb_to_lab(rgb);
            (
                rgb,
                cie94::nearest_idx(lab),
                cie94_wasm_simd128::nearest_idx(lab),
            )
        })
        .filter(|(_, expected, actual)| expected != actual)
        .collect();

    // At most five disagreements are shown to keep the failure message short.
    let preview = &mismatches[..mismatches.len().min(5)];
    assert!(
        mismatches.is_empty(),
        "{} CIE94 scalar/WASM SIMD128 mismatches; first few: {:?}",
        mismatches.len(),
        preview
    );
}
/// The WASM simd128 kernel must pick the same palette index as the scalar
/// implementation for every query in `parity_grid`.
#[test]
#[cfg_attr(miri, ignore = "4913-query × 949-entry grid is too slow under miri")]
#[cfg(all(target_arch = "wasm32", target_feature = "simd128", feature = "std"))]
fn wasm_simd128_and_scalar_agree_across_grid() {
    // Keep only the queries on which the backends disagree,
    // as (rgb, scalar index, simd128 index).
    let mismatches: Vec<_> = parity_grid()
        .map(|rgb| {
            let lab = crate::rgb_to_lab(rgb);
            (rgb, scalar::nearest_idx(lab), wasm_simd128::nearest_idx(lab))
        })
        .filter(|(_, expected, actual)| expected != actual)
        .collect();

    // At most five disagreements are shown to keep the failure message short.
    let preview = &mismatches[..mismatches.len().min(5)];
    assert!(
        mismatches.is_empty(),
        "{} scalar/WASM SIMD128 mismatches; first few: {:?}",
        mismatches.len(),
        preview
    );
}
}