use crate::{GpuDevice, Result};
/// Color spaces that the kernels in this file can be parameterized with.
///
/// The enum is `Copy`, so all of its methods take `self` by value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
// The YUV variants keep an underscore before the standard suffix.
#[allow(non_camel_case_types)]
pub enum ColorSpace {
    /// Displayed as `"RGB"`.
    RGB,
    /// Displayed as `"YUV (BT.601)"`.
    YUV_BT601,
    /// Displayed as `"YUV (BT.709)"`.
    YUV_BT709,
    /// Displayed as `"YUV (BT.2020)"`.
    YUV_BT2020,
    /// Displayed as `"HSV"`.
    HSV,
    /// Displayed as `"HSL"`.
    HSL,
    /// Displayed as `"CIE Lab"`.
    Lab,
    /// Displayed as `"Linear RGB"`.
    LinearRGB,
    /// Displayed as `"sRGB"`.
    SRGB,
}

impl ColorSpace {
    /// Returns `true` for `YUV_BT601`, `YUV_BT709` and `YUV_BT2020`, `false` otherwise.
    #[must_use]
    pub fn is_yuv(self) -> bool {
        // Exhaustive on purpose: adding a variant forces a decision here.
        match self {
            ColorSpace::YUV_BT601 | ColorSpace::YUV_BT709 | ColorSpace::YUV_BT2020 => true,
            ColorSpace::RGB
            | ColorSpace::HSV
            | ColorSpace::HSL
            | ColorSpace::Lab
            | ColorSpace::LinearRGB
            | ColorSpace::SRGB => false,
        }
    }

    /// Returns `true` for `RGB`, `LinearRGB` and `SRGB`, `false` otherwise.
    #[must_use]
    pub fn is_rgb(self) -> bool {
        match self {
            ColorSpace::RGB | ColorSpace::LinearRGB | ColorSpace::SRGB => true,
            ColorSpace::YUV_BT601
            | ColorSpace::YUV_BT709
            | ColorSpace::YUV_BT2020
            | ColorSpace::HSV
            | ColorSpace::HSL
            | ColorSpace::Lab => false,
        }
    }

    /// Returns a fixed, human-readable label for the color space.
    #[must_use]
    pub fn name(self) -> &'static str {
        match self {
            // RGB family
            ColorSpace::RGB => "RGB",
            ColorSpace::LinearRGB => "Linear RGB",
            ColorSpace::SRGB => "sRGB",
            // YUV family
            ColorSpace::YUV_BT601 => "YUV (BT.601)",
            ColorSpace::YUV_BT709 => "YUV (BT.709)",
            ColorSpace::YUV_BT2020 => "YUV (BT.2020)",
            // Everything else
            ColorSpace::HSV => "HSV",
            ColorSpace::HSL => "HSL",
            ColorSpace::Lab => "CIE Lab",
        }
    }
}
/// Converts this file's [`ColorSpace`] into `crate::ops::ColorSpace`.
///
/// Only `YUV_BT709` and `YUV_BT2020` map to their own counterpart; every
/// other variant (`YUV_BT601`, `RGB` and all non-YUV spaces) resolves to
/// `BT601`.
impl From<ColorSpace> for crate::ops::ColorSpace {
    fn from(space: ColorSpace) -> Self {
        // Listed exhaustively so that a new `ColorSpace` variant has to pick
        // its mapping explicitly instead of silently falling back.
        match space {
            ColorSpace::YUV_BT709 => Self::BT709,
            ColorSpace::YUV_BT2020 => Self::BT2020,
            ColorSpace::YUV_BT601
            | ColorSpace::RGB
            | ColorSpace::HSV
            | ColorSpace::HSL
            | ColorSpace::Lab
            | ColorSpace::LinearRGB
            | ColorSpace::SRGB => Self::BT601,
        }
    }
}
/// Conversions a `ColorConversionKernel` can be configured to run.
///
/// Variants are named `<source>to<destination>`; `ColorConversionKernel::execute`
/// dispatches each one to the correspondingly named function on
/// `crate::ops::ColorSpaceConversion`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ColorConversion {
/// Dispatched to `rgb_to_yuv`.
RGBtoYUV,
/// Dispatched to `yuv_to_rgb`.
YUVtoRGB,
/// Dispatched to `rgb_to_hsv`.
RGBtoHSV,
/// Dispatched to `hsv_to_rgb`.
HSVtoRGB,
/// Dispatched to `rgb_to_lab`.
RGBtoLab,
/// Dispatched to `lab_to_rgb`.
LabtoRGB,
/// Dispatched to `srgb_to_linear`.
SRGBtoLinear,
/// Dispatched to `linear_to_srgb`.
LinearToSRGB,
}
pub struct ColorConversionKernel {
conversion: ColorConversion,
color_space: ColorSpace,
}
impl ColorConversionKernel {
#[must_use]
pub fn new(conversion: ColorConversion, color_space: ColorSpace) -> Self {
Self {
conversion,
color_space,
}
}
#[must_use]
pub fn rgb_to_yuv(color_space: ColorSpace) -> Self {
Self::new(ColorConversion::RGBtoYUV, color_space)
}
#[must_use]
pub fn yuv_to_rgb(color_space: ColorSpace) -> Self {
Self::new(ColorConversion::YUVtoRGB, color_space)
}
pub fn execute(
&self,
device: &GpuDevice,
input: &[u8],
output: &mut [u8],
width: u32,
height: u32,
) -> Result<()> {
match self.conversion {
ColorConversion::RGBtoYUV => crate::ops::ColorSpaceConversion::rgb_to_yuv(
device,
input,
output,
width,
height,
self.color_space.into(),
),
ColorConversion::YUVtoRGB => crate::ops::ColorSpaceConversion::yuv_to_rgb(
device,
input,
output,
width,
height,
self.color_space.into(),
),
ColorConversion::RGBtoHSV => {
let result = crate::ops::ColorSpaceConversion::rgb_to_hsv(input, width, height);
let copy_len = result.len().min(output.len());
output[..copy_len].copy_from_slice(&result[..copy_len]);
Ok(())
}
ColorConversion::HSVtoRGB => {
let result = crate::ops::ColorSpaceConversion::hsv_to_rgb(input, width, height);
let copy_len = result.len().min(output.len());
output[..copy_len].copy_from_slice(&result[..copy_len]);
Ok(())
}
ColorConversion::RGBtoLab => {
let result = crate::ops::ColorSpaceConversion::rgb_to_lab(input, width, height);
let copy_len = result.len().min(output.len());
output[..copy_len].copy_from_slice(&result[..copy_len]);
Ok(())
}
ColorConversion::LabtoRGB => {
let result = crate::ops::ColorSpaceConversion::lab_to_rgb(input, width, height);
let copy_len = result.len().min(output.len());
output[..copy_len].copy_from_slice(&result[..copy_len]);
Ok(())
}
ColorConversion::SRGBtoLinear => {
let result = crate::ops::ColorSpaceConversion::srgb_to_linear(input, width, height);
let copy_len = result.len().min(output.len());
output[..copy_len].copy_from_slice(&result[..copy_len]);
Ok(())
}
ColorConversion::LinearToSRGB => {
let result = crate::ops::ColorSpaceConversion::linear_to_srgb(input, width, height);
let copy_len = result.len().min(output.len());
output[..copy_len].copy_from_slice(&result[..copy_len]);
Ok(())
}
}
}
#[must_use]
pub fn conversion(&self) -> ColorConversion {
self.conversion
}
#[must_use]
pub fn color_space(&self) -> ColorSpace {
self.color_space
}
#[must_use]
pub fn output_size(width: u32, height: u32, channels: u32) -> usize {
(width * height * channels) as usize
}
#[must_use]
pub fn estimate_flops(width: u32, height: u32, conversion: ColorConversion) -> u64 {
let pixels = u64::from(width) * u64::from(height);
match conversion {
ColorConversion::RGBtoYUV | ColorConversion::YUVtoRGB => {
pixels * 15
}
ColorConversion::RGBtoHSV | ColorConversion::HSVtoRGB => {
pixels * 20
}
ColorConversion::RGBtoLab | ColorConversion::LabtoRGB => {
pixels * 50
}
ColorConversion::SRGBtoLinear | ColorConversion::LinearToSRGB => {
pixels * 3 * 5
}
}
}
}
/// Applies 1D (per-channel) and 3D (RGB lattice) look-up tables to 8-bit data.
///
/// Both `apply_*` methods compute on the CPU; their device, width and height
/// parameters are currently unused.
pub struct LutKernel {
    /// Entries per channel for 1D tables; lattice points per axis for 3D tables.
    lut_size: usize,
}

impl LutKernel {
    /// Creates a kernel for tables with `lut_size` entries per channel / axis.
    #[must_use]
    pub fn new(lut_size: usize) -> Self {
        Self { lut_size }
    }

    /// Returns the table size this kernel was created with.
    #[must_use]
    pub fn lut_size(&self) -> usize {
        self.lut_size
    }

    /// Maps every byte of `input` through a per-channel 1D table.
    ///
    /// `lut` holds `channels` consecutive tables of `lut_size` entries each,
    /// where `channels = lut.len() / lut_size` (trailing entries that do not
    /// form a whole table are ignored). `input` is read as interleaved pixels
    /// of `channels` bytes; each byte `v` selects entry
    /// `round(v * (lut_size - 1) / 255)` of its channel's table. Bytes after
    /// the last whole pixel are copied through unchanged. Only the first
    /// `input.len()` bytes of `output` are written.
    ///
    /// # Errors
    ///
    /// Returns `GpuError::NotSupported` when `lut_size` is zero, `lut` is
    /// empty or shorter than one table, or `output` is shorter than `input`
    /// (nothing is written in that case).
    #[allow(clippy::too_many_arguments)]
    pub fn apply_1d(
        &self,
        _device: &GpuDevice,
        input: &[u8],
        output: &mut [u8],
        lut: &[u8],
        _width: u32,
        _height: u32,
    ) -> Result<()> {
        if self.lut_size == 0 || lut.is_empty() {
            return Err(crate::GpuError::NotSupported(
                "1D LUT size must be non-zero".to_string(),
            ));
        }
        let channels = lut.len() / self.lut_size;
        if channels == 0 {
            return Err(crate::GpuError::NotSupported(
                "1D LUT must cover at least one channel".to_string(),
            ));
        }
        let output = Self::output_window(input.len(), output)?;

        let lut_max = self.lut_size - 1;
        let whole = input.len() / channels * channels;
        let (src_pixels, src_tail) = input.split_at(whole);
        let (dst_pixels, dst_tail) = output.split_at_mut(whole);

        let pixel_pairs = src_pixels
            .chunks_exact(channels)
            .zip(dst_pixels.chunks_exact_mut(channels));
        for (src, dst) in pixel_pairs {
            // One table per channel, in channel order.
            let tables = lut.chunks_exact(self.lut_size);
            for ((&value, slot), table) in src.iter().zip(dst.iter_mut()).zip(tables) {
                // +127 rounds to the nearest table entry before the integer division.
                let idx = ((usize::from(value) * lut_max + 127) / 255).min(lut_max);
                *slot = table[idx];
            }
        }
        dst_tail.copy_from_slice(src_tail);
        Ok(())
    }

    /// Maps RGB pixels through a 3D table using trilinear interpolation.
    ///
    /// `lut` is an `n x n x n` lattice (`n = lut_size`) of RGB triples stored
    /// at index `(r * n * n + g * n + b) * 3 + channel`. `input` is read as
    /// 3-byte pixels; each channel byte is scaled to `[0, n - 1]`, the eight
    /// surrounding lattice points are blended, and the result is clamped to
    /// `[0, 1]` and rounded to a byte. Bytes after the last whole pixel are
    /// copied through unchanged. Only the first `input.len()` bytes of
    /// `output` are written.
    ///
    /// # Errors
    ///
    /// Returns `GpuError::NotSupported` when `lut_size` is zero, when
    /// `n * n * n * 3` overflows `usize`, when `lut` has fewer entries than
    /// that, or when `output` is shorter than `input` (nothing is written in
    /// that case).
    #[allow(clippy::too_many_arguments)]
    pub fn apply_3d(
        &self,
        _device: &GpuDevice,
        input: &[u8],
        output: &mut [u8],
        lut: &[f32],
        _width: u32,
        _height: u32,
    ) -> Result<()> {
        let n = self.lut_size;
        if n == 0 {
            return Err(crate::GpuError::NotSupported(
                "3D LUT size must be non-zero".to_string(),
            ));
        }
        // Checked so an absurd `n` is reported instead of panicking or wrapping.
        let expected_lut = n
            .checked_mul(n)
            .and_then(|v| v.checked_mul(n))
            .and_then(|v| v.checked_mul(3))
            .ok_or_else(|| {
                crate::GpuError::NotSupported(format!("3D LUT size {n} is too large"))
            })?;
        if lut.len() < expected_lut {
            return Err(crate::GpuError::NotSupported(format!(
                "3D LUT too small: expected {expected_lut} entries, got {}",
                lut.len()
            )));
        }
        let output = Self::output_window(input.len(), output)?;

        const PIXEL_STRIDE: usize = 3;
        let top = n - 1;
        let whole = input.len() / PIXEL_STRIDE * PIXEL_STRIDE;
        let (src_pixels, src_tail) = input.split_at(whole);
        let (dst_pixels, dst_tail) = output.split_at_mut(whole);

        let sample = |ri: usize, gi: usize, bi: usize, ch: usize| -> f32 {
            lut[(ri * n * n + gi * n + bi) * 3 + ch]
        };
        let lerp = |a: f32, b: f32, t: f32| -> f32 { a * (1.0 - t) + b * t };

        let pixel_pairs = src_pixels
            .chunks_exact(PIXEL_STRIDE)
            .zip(dst_pixels.chunks_exact_mut(PIXEL_STRIDE));
        for (src, dst) in pixel_pairs {
            let (r0, r1, fr) = Self::lattice_span(src[0], top);
            let (g0, g1, fg) = Self::lattice_span(src[1], top);
            let (b0, b1, fb) = Self::lattice_span(src[2], top);

            for (ch, slot) in dst.iter_mut().enumerate() {
                // Blend along red first, then green, then blue.
                let c00 = lerp(sample(r0, g0, b0, ch), sample(r1, g0, b0, ch), fr);
                let c01 = lerp(sample(r0, g0, b1, ch), sample(r1, g0, b1, ch), fr);
                let c10 = lerp(sample(r0, g1, b0, ch), sample(r1, g1, b0, ch), fr);
                let c11 = lerp(sample(r0, g1, b1, ch), sample(r1, g1, b1, ch), fr);
                let c0 = lerp(c00, c10, fg);
                let c1 = lerp(c01, c11, fg);
                let val = lerp(c0, c1, fb);
                *slot = (val.clamp(0.0, 1.0) * 255.0).round() as u8;
            }
        }
        dst_tail.copy_from_slice(src_tail);
        Ok(())
    }

    /// Returns the first `needed` bytes of `output`, or an error if it is shorter.
    fn output_window(needed: usize, output: &mut [u8]) -> Result<&mut [u8]> {
        let available = output.len();
        output.get_mut(..needed).ok_or_else(|| {
            crate::GpuError::NotSupported(format!(
                "output buffer too small: need {needed} bytes, got {available}"
            ))
        })
    }

    /// Locates `value` on one lattice axis whose last index is `top`.
    ///
    /// Returns `(lower index, upper index, fractional distance from lower)`;
    /// both indices are clamped to `top`.
    fn lattice_span(value: u8, top: usize) -> (usize, usize, f32) {
        let pos = f32::from(value) / 255.0 * top as f32;
        let lo = (pos.floor() as usize).min(top);
        (lo, (lo + 1).min(top), pos - lo as f32)
    }
}
#[cfg(test)]
mod tests {
use super::*;
/// Checks `is_yuv` / `is_rgb` on the YUV and RGB families plus one
/// counter-example each.
#[test]
fn test_color_space_properties() {
    let yuv_family = [
        ColorSpace::YUV_BT601,
        ColorSpace::YUV_BT709,
        ColorSpace::YUV_BT2020,
    ];
    for space in yuv_family.iter() {
        assert!(space.is_yuv());
    }
    assert!(!ColorSpace::RGB.is_yuv());

    let rgb_family = [ColorSpace::RGB, ColorSpace::LinearRGB, ColorSpace::SRGB];
    for space in rgb_family.iter() {
        assert!(space.is_rgb());
    }
    assert!(!ColorSpace::YUV_BT601.is_rgb());
}
/// The `rgb_to_yuv` constructor must store both the conversion and the color space.
#[test]
fn test_color_conversion_kernel() {
    let space = ColorSpace::YUV_BT709;
    let kernel = ColorConversionKernel::rgb_to_yuv(space);
    let observed = (kernel.conversion(), kernel.color_space());
    assert_eq!(observed.0, ColorConversion::RGBtoYUV);
    assert_eq!(observed.1, ColorSpace::YUV_BT709);
}
/// The estimate must be positive and larger for Lab than for YUV at the same size.
#[test]
fn test_flops_estimation() {
    let estimate =
        |conversion| ColorConversionKernel::estimate_flops(1920, 1080, conversion);

    let flops = estimate(ColorConversion::RGBtoYUV);
    assert!(flops > 0);

    let flops_lab = estimate(ColorConversion::RGBtoLab);
    assert!(flops_lab > flops);
}
/// Builds `channels` consecutive copies of a linear ramp with `lut_size`
/// entries, where entry `i` is `i * 255 / (lut_size - 1)` (integer division).
///
/// The divisor is evaluated lazily per entry, so `lut_size == 0` or
/// `channels == 0` yields an empty table without touching it.
fn identity_lut_1d(lut_size: usize, channels: usize) -> Vec<u8> {
    (0..channels)
        .flat_map(|_| (0..lut_size).map(move |i| ((i * 255) / (lut_size - 1)) as u8))
        .collect()
}
/// Builds an `n x n x n` lattice of RGB triples where the entry at
/// `(ri, gi, bi)` is `(ri, gi, bi) / (n - 1)`.
///
/// Triples are appended in `ri`-major, then `gi`, then `bi` order, which
/// places each one at index `(ri * n * n + gi * n + bi) * 3`.
fn identity_lut_3d(n: usize) -> Vec<f32> {
    let mut lut = Vec::with_capacity(n * n * n * 3);
    for ri in 0..n {
        for gi in 0..n {
            for bi in 0..n {
                // Computed inside the loops so `n == 0` never evaluates `n - 1`.
                let denom = (n - 1) as f32;
                lut.extend([ri, gi, bi].iter().map(|&i| i as f32 / denom));
            }
        }
    }
    lut
}
/// An identity 1D table must reproduce every input byte to within 1.
///
/// NOTE(review): this mirrors the index arithmetic of `LutKernel::apply_1d`
/// rather than calling it — presumably because no `GpuDevice` is available
/// in unit tests; confirm.
#[test]
fn test_apply_1d_identity() {
    let lut_size = 256usize;
    let channels = 3usize;
    let lut = identity_lut_1d(lut_size, channels);
    let input: Vec<u8> = vec![0, 128, 255, 64, 192, 10];
    let kernel = LutKernel::new(lut_size);
    let lut_max = lut_size - 1;

    let mut output = vec![0u8; input.len()];
    let pixel_pairs = input
        .chunks_exact(channels)
        .zip(output.chunks_exact_mut(channels));
    for (src, dst) in pixel_pairs {
        for (c, (&value, slot)) in src.iter().zip(dst.iter_mut()).enumerate() {
            let lut_idx = ((usize::from(value) * lut_max + 127) / 255).min(lut_max);
            *slot = lut[c * kernel.lut_size() + lut_idx];
        }
    }

    for (i, (&inp, &out)) in input.iter().zip(output.iter()).enumerate() {
        let diff = i32::from(inp) - i32::from(out);
        assert!(diff.abs() <= 1, "pixel {i}: input={inp}, output={out}");
    }
}
/// An inverting single-channel table must map 0 to 255 and 255 to 0.
///
/// NOTE(review): this mirrors the index arithmetic of `LutKernel::apply_1d`
/// rather than calling it — presumably because no `GpuDevice` is available
/// in unit tests; confirm.
#[test]
fn test_apply_1d_invert() {
    let lut_size = 256usize;
    // Descending ramp: entry `i` holds `255 - i`.
    let lut: Vec<u8> = (0..lut_size).rev().map(|i| i as u8).collect();
    let input: Vec<u8> = vec![0, 64, 128, 192, 255];
    let lut_max = lut_size - 1;

    let output: Vec<u8> = input
        .iter()
        .map(|&v| lut[((usize::from(v) * lut_max + 127) / 255).min(lut_max)])
        .collect();

    assert_eq!(output[0], 255);
    assert_eq!(output[4], 0);
}
/// A 17-point identity lattice must reproduce every input byte to within 2.
///
/// NOTE(review): this mirrors the trilinear arithmetic of
/// `LutKernel::apply_3d` rather than calling it — presumably because no
/// `GpuDevice` is available in unit tests; confirm.
#[test]
fn test_apply_3d_identity() {
    fn lerp(a: f32, b: f32, t: f32) -> f32 {
        a * (1.0 - t) + b * t
    }

    let n = 17usize;
    let lut = identity_lut_3d(n);
    let input: Vec<u8> = vec![0, 0, 0, 128, 64, 192, 255, 255, 255];
    let mut output = vec![0u8; input.len()];
    let top = n - 1;
    let nf = top as f32;

    // (lower index, upper index, fractional distance) along one lattice axis.
    let locate = |byte: u8| -> (usize, usize, f32) {
        let pos = byte as f32 / 255.0 * nf;
        let lo = (pos.floor() as usize).min(top);
        (lo, (lo + 1).min(top), pos - lo as f32)
    };
    let fetch = |ri: usize, gi: usize, bi: usize, ch: usize| -> f32 {
        lut[(ri * n * n + gi * n + bi) * 3 + ch]
    };

    for (src, dst) in input.chunks_exact(3).zip(output.chunks_exact_mut(3)) {
        let (r0, r1, fr) = locate(src[0]);
        let (g0, g1, fg) = locate(src[1]);
        let (b0, b1, fb) = locate(src[2]);
        for (ch, slot) in dst.iter_mut().enumerate() {
            let c00 = lerp(fetch(r0, g0, b0, ch), fetch(r1, g0, b0, ch), fr);
            let c01 = lerp(fetch(r0, g0, b1, ch), fetch(r1, g0, b1, ch), fr);
            let c10 = lerp(fetch(r0, g1, b0, ch), fetch(r1, g1, b0, ch), fr);
            let c11 = lerp(fetch(r0, g1, b1, ch), fetch(r1, g1, b1, ch), fr);
            let val = lerp(lerp(c00, c10, fg), lerp(c01, c11, fg), fb);
            *slot = (val.clamp(0.0, 1.0) * 255.0).round() as u8;
        }
    }

    for (i, (&inp, &out)) in input.iter().zip(output.iter()).enumerate() {
        let diff = i32::from(inp) - i32::from(out);
        assert!(
            diff.abs() <= 2,
            "channel byte {i}: input={inp}, output={out}"
        );
    }
}
/// On a 2-point identity lattice, pure black and pure white must map to themselves.
///
/// NOTE(review): this mirrors the trilinear arithmetic of
/// `LutKernel::apply_3d` rather than calling it — presumably because no
/// `GpuDevice` is available in unit tests; confirm.
#[test]
fn test_apply_3d_black_white() {
    fn lerp(a: f32, b: f32, t: f32) -> f32 {
        a * (1.0 - t) + b * t
    }

    let n = 2usize;
    let lut = identity_lut_3d(n);
    let input: Vec<u8> = vec![0, 0, 0, 255, 255, 255];
    let mut output = vec![0u8; 6];
    let top = n - 1;
    let nf = top as f32;

    // (lower index, upper index, fractional distance) along one lattice axis.
    let locate = |byte: u8| -> (usize, usize, f32) {
        let pos = byte as f32 / 255.0 * nf;
        let lo = (pos.floor() as usize).min(top);
        (lo, (lo + 1).min(top), pos - lo as f32)
    };
    let fetch = |ri: usize, gi: usize, bi: usize, ch: usize| -> f32 {
        lut[(ri * n * n + gi * n + bi) * 3 + ch]
    };

    for (src, dst) in input.chunks_exact(3).zip(output.chunks_exact_mut(3)) {
        let (r0, r1, fr) = locate(src[0]);
        let (g0, g1, fg) = locate(src[1]);
        let (b0, b1, fb) = locate(src[2]);
        for (ch, slot) in dst.iter_mut().enumerate() {
            let c00 = lerp(fetch(r0, g0, b0, ch), fetch(r1, g0, b0, ch), fr);
            let c01 = lerp(fetch(r0, g0, b1, ch), fetch(r1, g0, b1, ch), fr);
            let c10 = lerp(fetch(r0, g1, b0, ch), fetch(r1, g1, b0, ch), fr);
            let c11 = lerp(fetch(r0, g1, b1, ch), fetch(r1, g1, b1, ch), fr);
            let val = lerp(lerp(c00, c10, fg), lerp(c01, c11, fg), fb);
            *slot = (val.clamp(0.0, 1.0) * 255.0).round() as u8;
        }
    }

    assert_eq!(&output[0..3], &[0u8, 0, 0]);
    assert_eq!(&output[3..6], &[255u8, 255, 255]);
}
}