Skip to main content

oximedia_gpu/kernels/
color.rs

1//! Color space conversion kernels
2
3use crate::{GpuDevice, Result};
4
/// Color space standards understood by the color kernels.
///
/// Variant names spell out the standard they refer to (e.g. `YUV_BT709`),
/// which is why `non_camel_case_types` is allowed on this enum.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
// Renaming the BT.xxx variants to CamelCase would break the public API.
#[allow(non_camel_case_types)]
pub enum ColorSpace {
    /// RGB color space
    RGB,
    /// YUV with BT.601 coefficients (SD video)
    YUV_BT601,
    /// YUV with BT.709 coefficients (HD video)
    YUV_BT709,
    /// YUV with BT.2020 coefficients (UHD video)
    YUV_BT2020,
    /// HSV color space
    HSV,
    /// HSL color space
    HSL,
    /// CIE Lab color space
    Lab,
    /// Linear RGB
    LinearRGB,
    /// sRGB
    SRGB,
}

impl ColorSpace {
    /// Returns `true` for the YUV variants (`YUV_BT601`, `YUV_BT709`, `YUV_BT2020`).
    #[must_use]
    pub fn is_yuv(self) -> bool {
        matches!(self, Self::YUV_BT601 | Self::YUV_BT709 | Self::YUV_BT2020)
    }

    /// Returns `true` for the RGB-family variants (`RGB`, `LinearRGB`, `SRGB`).
    #[must_use]
    pub fn is_rgb(self) -> bool {
        matches!(self, Self::RGB | Self::LinearRGB | Self::SRGB)
    }

    /// Human-readable name of the color space (e.g. `"YUV (BT.709)"`).
    #[must_use]
    pub fn name(self) -> &'static str {
        match self {
            Self::RGB => "RGB",
            Self::YUV_BT601 => "YUV (BT.601)",
            Self::YUV_BT709 => "YUV (BT.709)",
            Self::YUV_BT2020 => "YUV (BT.2020)",
            Self::HSV => "HSV",
            Self::HSL => "HSL",
            Self::Lab => "CIE Lab",
            Self::LinearRGB => "Linear RGB",
            Self::SRGB => "sRGB",
        }
    }
}
58
59impl From<ColorSpace> for crate::ops::ColorSpace {
60    fn from(space: ColorSpace) -> Self {
61        match space {
62            ColorSpace::YUV_BT601 | ColorSpace::RGB => Self::BT601,
63            ColorSpace::YUV_BT709 => Self::BT709,
64            ColorSpace::YUV_BT2020 => Self::BT2020,
65            _ => Self::BT601, // Default fallback
66        }
67    }
68}
69
/// Color conversion operation performed by a [`ColorConversionKernel`].
///
/// Variants come in forward/inverse pairs.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ColorConversion {
    /// RGB to YUV
    RGBtoYUV,
    /// YUV to RGB
    YUVtoRGB,
    /// RGB to HSV
    RGBtoHSV,
    /// HSV to RGB
    HSVtoRGB,
    /// RGB to Lab
    RGBtoLab,
    /// Lab to RGB
    LabtoRGB,
    /// sRGB to Linear RGB
    SRGBtoLinear,
    /// Linear RGB to sRGB
    LinearToSRGB,
}
90
91/// Color space conversion kernel
92pub struct ColorConversionKernel {
93    conversion: ColorConversion,
94    color_space: ColorSpace,
95}
96
97impl ColorConversionKernel {
98    /// Create a new color conversion kernel
99    #[must_use]
100    pub fn new(conversion: ColorConversion, color_space: ColorSpace) -> Self {
101        Self {
102            conversion,
103            color_space,
104        }
105    }
106
107    /// Create an RGB to YUV conversion kernel
108    #[must_use]
109    pub fn rgb_to_yuv(color_space: ColorSpace) -> Self {
110        Self::new(ColorConversion::RGBtoYUV, color_space)
111    }
112
113    /// Create a YUV to RGB conversion kernel
114    #[must_use]
115    pub fn yuv_to_rgb(color_space: ColorSpace) -> Self {
116        Self::new(ColorConversion::YUVtoRGB, color_space)
117    }
118
119    /// Execute the color conversion
120    ///
121    /// # Arguments
122    ///
123    /// * `device` - GPU device
124    /// * `input` - Input image buffer
125    /// * `output` - Output image buffer
126    /// * `width` - Image width
127    /// * `height` - Image height
128    ///
129    /// # Errors
130    ///
131    /// Returns an error if the conversion fails.
132    pub fn execute(
133        &self,
134        device: &GpuDevice,
135        input: &[u8],
136        output: &mut [u8],
137        width: u32,
138        height: u32,
139    ) -> Result<()> {
140        match self.conversion {
141            ColorConversion::RGBtoYUV => crate::ops::ColorSpaceConversion::rgb_to_yuv(
142                device,
143                input,
144                output,
145                width,
146                height,
147                self.color_space.into(),
148            ),
149            ColorConversion::YUVtoRGB => crate::ops::ColorSpaceConversion::yuv_to_rgb(
150                device,
151                input,
152                output,
153                width,
154                height,
155                self.color_space.into(),
156            ),
157            _ => {
158                // For other conversions, we would need to implement additional kernels
159                // For now, return an error
160                Err(crate::GpuError::NotSupported(format!(
161                    "Color conversion {:?} not yet implemented",
162                    self.conversion
163                )))
164            }
165        }
166    }
167
168    /// Get the conversion type
169    #[must_use]
170    pub fn conversion(&self) -> ColorConversion {
171        self.conversion
172    }
173
174    /// Get the color space
175    #[must_use]
176    pub fn color_space(&self) -> ColorSpace {
177        self.color_space
178    }
179
180    /// Calculate output buffer size
181    #[must_use]
182    pub fn output_size(width: u32, height: u32, channels: u32) -> usize {
183        (width * height * channels) as usize
184    }
185
186    /// Estimate FLOPS for color conversion
187    #[must_use]
188    pub fn estimate_flops(width: u32, height: u32, conversion: ColorConversion) -> u64 {
189        let pixels = u64::from(width) * u64::from(height);
190
191        match conversion {
192            ColorConversion::RGBtoYUV | ColorConversion::YUVtoRGB => {
193                // Matrix multiplication: 3x3 * 3 = 9 multiplies + 6 adds per pixel
194                pixels * 15
195            }
196            ColorConversion::RGBtoHSV | ColorConversion::HSVtoRGB => {
197                // HSV conversion involves min/max operations and divisions
198                pixels * 20
199            }
200            ColorConversion::RGBtoLab | ColorConversion::LabtoRGB => {
201                // Lab conversion is more complex with power functions
202                pixels * 50
203            }
204            ColorConversion::SRGBtoLinear | ColorConversion::LinearToSRGB => {
205                // Gamma correction per component
206                pixels * 3 * 5
207            }
208        }
209    }
210}
211
/// Lookup table (LUT) based color transformation.
///
/// Only the LUT dimension is stored. The table itself is passed to
/// `apply_1d` / `apply_3d` on each call.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct LutKernel {
    /// Entries per channel (1D LUT) or grid points per axis (3D LUT).
    lut_size: usize,
}
216
217impl LutKernel {
218    /// Create a new LUT kernel
219    ///
220    /// # Arguments
221    ///
222    /// * `lut_size` - Size of the LUT (typically 256 for 1D or 33 for 3D)
223    #[must_use]
224    pub fn new(lut_size: usize) -> Self {
225        Self { lut_size }
226    }
227
228    /// Get the LUT size
229    #[must_use]
230    pub fn lut_size(&self) -> usize {
231        self.lut_size
232    }
233
234    /// Apply 1D LUT transformation (CPU fallback).
235    ///
236    /// The LUT layout is `[CH0_LUT[0..lut_size], CH1_LUT[0..lut_size], …]`.
237    /// For each pixel and each channel `c`, the output value is
238    /// `lut[c * lut_size + idx]` where `idx = (pixel_value * (lut_size-1)) / 255`.
239    ///
240    /// Pixels and the LUT are treated as having the same number of channels,
241    /// inferred from `lut.len() / lut_size`.  If `lut.len()` is not a multiple
242    /// of `lut_size`, the extra channel bytes are copied unchanged.
243    ///
244    /// # Arguments
245    ///
246    /// * `_device` - GPU device (CPU fallback: unused)
247    /// * `input` - Input image buffer
248    /// * `output` - Output image buffer (same length as `input`)
249    /// * `lut` - 1D lookup table (size: `lut_size * channels`)
250    /// * `_width` - Image width (unused; length is inferred from buffers)
251    /// * `_height` - Image height (unused; length is inferred from buffers)
252    ///
253    /// # Errors
254    ///
255    /// Returns an error if `lut_size` is zero or the LUT is empty.
256    #[allow(clippy::too_many_arguments)]
257    pub fn apply_1d(
258        &self,
259        _device: &GpuDevice,
260        input: &[u8],
261        output: &mut [u8],
262        lut: &[u8],
263        _width: u32,
264        _height: u32,
265    ) -> Result<()> {
266        if self.lut_size == 0 || lut.is_empty() {
267            return Err(crate::GpuError::NotSupported(
268                "1D LUT size must be non-zero".to_string(),
269            ));
270        }
271        let channels = lut.len() / self.lut_size;
272        if channels == 0 {
273            return Err(crate::GpuError::NotSupported(
274                "1D LUT must cover at least one channel".to_string(),
275            ));
276        }
277        let lut_max = self.lut_size - 1;
278        // Process each pixel (each group of `channels` bytes in the input).
279        // Any trailing bytes beyond a complete pixel are copied unchanged.
280        let full_pixels = input.len() / channels;
281        for px in 0..full_pixels {
282            let base = px * channels;
283            for c in 0..channels {
284                let pixel_val = input[base + c] as usize;
285                // Scale pixel value [0..=255] to lut index [0..=lut_max].
286                let lut_idx = (pixel_val * lut_max + 127) / 255; // round-to-nearest
287                let lut_idx = lut_idx.min(lut_max);
288                output[base + c] = lut[c * self.lut_size + lut_idx];
289            }
290        }
291        // Copy any trailing bytes (partial pixel) unchanged.
292        let tail_start = full_pixels * channels;
293        output[tail_start..input.len()].copy_from_slice(&input[tail_start..]);
294        Ok(())
295    }
296
297    /// Apply 3D LUT transformation with trilinear interpolation (CPU fallback).
298    ///
299    /// The LUT is a cubic grid of size `N × N × N` (where `N = lut_size`)
300    /// storing RGB triplets, laid out as
301    /// `lut[(r_idx * N*N + g_idx * N + b_idx) * 3 + channel]`.
302    ///
303    /// Input pixels are expected to be interleaved 3-channel (RGB) data.
304    /// Any extra channels beyond the first three are passed through unchanged.
305    ///
306    /// # Arguments
307    ///
308    /// * `_device` - GPU device (CPU fallback: unused)
309    /// * `input` - Input image buffer (interleaved RGB, 3 bytes per pixel minimum)
310    /// * `output` - Output image buffer (same length as `input`)
311    /// * `lut` - 3D LUT (`lut_size^3 * 3` f32 entries, values in `[0.0, 1.0]`)
312    /// * `_width` - Image width (unused; length is inferred from buffers)
313    /// * `_height` - Image height (unused; length is inferred from buffers)
314    ///
315    /// # Errors
316    ///
317    /// Returns an error if `lut_size` is zero or the LUT is too small.
318    #[allow(clippy::too_many_arguments)]
319    pub fn apply_3d(
320        &self,
321        _device: &GpuDevice,
322        input: &[u8],
323        output: &mut [u8],
324        lut: &[f32],
325        _width: u32,
326        _height: u32,
327    ) -> Result<()> {
328        let n = self.lut_size;
329        if n == 0 {
330            return Err(crate::GpuError::NotSupported(
331                "3D LUT size must be non-zero".to_string(),
332            ));
333        }
334        let expected_lut = n * n * n * 3;
335        if lut.len() < expected_lut {
336            return Err(crate::GpuError::NotSupported(format!(
337                "3D LUT too small: expected {expected_lut} entries, got {}",
338                lut.len()
339            )));
340        }
341
342        // Process pixels in groups of 3 (RGB).
343        let pixel_stride = 3usize;
344        let full_pixels = input.len() / pixel_stride;
345
346        for px in 0..full_pixels {
347            let base = px * pixel_stride;
348            // Normalize each channel to [0.0, 1.0].
349            let r = f32::from(input[base]) / 255.0;
350            let g = f32::from(input[base + 1]) / 255.0;
351            let b = f32::from(input[base + 2]) / 255.0;
352
353            // Compute fractional position in the LUT grid.
354            let nf = (n - 1) as f32;
355            let rx = r * nf;
356            let gy = g * nf;
357            let bz = b * nf;
358
359            // Integer cube corner indices.
360            let r0 = (rx.floor() as usize).min(n - 1);
361            let g0 = (gy.floor() as usize).min(n - 1);
362            let b0 = (bz.floor() as usize).min(n - 1);
363            let r1 = (r0 + 1).min(n - 1);
364            let g1 = (g0 + 1).min(n - 1);
365            let b1 = (b0 + 1).min(n - 1);
366
367            // Fractional parts for trilinear interpolation.
368            let fr = rx - r0 as f32;
369            let fg = gy - g0 as f32;
370            let fb = bz - b0 as f32;
371
372            // Helper closure: fetch one channel value from the LUT.
373            let lut_val = |ri: usize, gi: usize, bi: usize, ch: usize| -> f32 {
374                lut[(ri * n * n + gi * n + bi) * 3 + ch]
375            };
376
377            for ch in 0..3 {
378                // Trilinear interpolation over the 8 cube corners.
379                let c000 = lut_val(r0, g0, b0, ch);
380                let c100 = lut_val(r1, g0, b0, ch);
381                let c010 = lut_val(r0, g1, b0, ch);
382                let c110 = lut_val(r1, g1, b0, ch);
383                let c001 = lut_val(r0, g0, b1, ch);
384                let c101 = lut_val(r1, g0, b1, ch);
385                let c011 = lut_val(r0, g1, b1, ch);
386                let c111 = lut_val(r1, g1, b1, ch);
387
388                let c00 = c000 * (1.0 - fr) + c100 * fr;
389                let c01 = c001 * (1.0 - fr) + c101 * fr;
390                let c10 = c010 * (1.0 - fr) + c110 * fr;
391                let c11 = c011 * (1.0 - fr) + c111 * fr;
392
393                let c0 = c00 * (1.0 - fg) + c10 * fg;
394                let c1 = c01 * (1.0 - fg) + c11 * fg;
395
396                let val = c0 * (1.0 - fb) + c1 * fb;
397                output[base + ch] = (val.clamp(0.0, 1.0) * 255.0).round() as u8;
398            }
399        }
400
401        // Copy any trailing bytes (partial pixel) unchanged.
402        let tail_start = full_pixels * pixel_stride;
403        output[tail_start..input.len()].copy_from_slice(&input[tail_start..]);
404        Ok(())
405    }
406}
407
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_color_space_properties() {
        assert!(ColorSpace::YUV_BT601.is_yuv());
        assert!(ColorSpace::YUV_BT709.is_yuv());
        assert!(ColorSpace::YUV_BT2020.is_yuv());
        assert!(!ColorSpace::RGB.is_yuv());

        assert!(ColorSpace::RGB.is_rgb());
        assert!(ColorSpace::LinearRGB.is_rgb());
        assert!(ColorSpace::SRGB.is_rgb());
        assert!(!ColorSpace::YUV_BT601.is_rgb());
    }

    #[test]
    fn test_color_space_names() {
        assert_eq!(ColorSpace::YUV_BT709.name(), "YUV (BT.709)");
        assert_eq!(ColorSpace::Lab.name(), "CIE Lab");
        assert_eq!(ColorSpace::SRGB.name(), "sRGB");
    }

    #[test]
    fn test_color_conversion_kernel() {
        let kernel = ColorConversionKernel::rgb_to_yuv(ColorSpace::YUV_BT709);
        assert_eq!(kernel.conversion(), ColorConversion::RGBtoYUV);
        assert_eq!(kernel.color_space(), ColorSpace::YUV_BT709);
    }

    #[test]
    fn test_output_size() {
        assert_eq!(
            ColorConversionKernel::output_size(1920, 1080, 3),
            1920 * 1080 * 3
        );
        assert_eq!(ColorConversionKernel::output_size(0, 1080, 3), 0);
    }

    #[test]
    fn test_flops_estimation() {
        let flops = ColorConversionKernel::estimate_flops(1920, 1080, ColorConversion::RGBtoYUV);
        assert!(flops > 0);

        let flops_lab =
            ColorConversionKernel::estimate_flops(1920, 1080, ColorConversion::RGBtoLab);
        assert!(flops_lab > flops); // Lab conversion should be more expensive
    }

    // --- CPU LUT reference tests (no GpuDevice required) ---------------------
    //
    // `LutKernel::apply_1d` / `apply_3d` take a `GpuDevice`, which these unit
    // tests do not construct. The tests below therefore run a standalone copy
    // of the CPU-fallback math.
    // NOTE(review): that copy must be kept in sync with the kernel by hand;
    // these tests do not call the kernel methods themselves.

    /// Build an identity 1D LUT for `channels` channels with `lut_size` entries.
    fn identity_lut_1d(lut_size: usize, channels: usize) -> Vec<u8> {
        let mut lut = vec![0u8; lut_size * channels];
        for c in 0..channels {
            for i in 0..lut_size {
                // Scale i back to [0..=255].
                lut[c * lut_size + i] = ((i * 255) / (lut_size - 1)) as u8;
            }
        }
        lut
    }

    /// Build an identity 3D LUT for N×N×N grid (3 channels, values in [0,1]).
    fn identity_lut_3d(n: usize) -> Vec<f32> {
        let mut lut = vec![0.0f32; n * n * n * 3];
        for ri in 0..n {
            for gi in 0..n {
                for bi in 0..n {
                    let base = (ri * n * n + gi * n + bi) * 3;
                    lut[base] = ri as f32 / (n - 1) as f32;
                    lut[base + 1] = gi as f32 / (n - 1) as f32;
                    lut[base + 2] = bi as f32 / (n - 1) as f32;
                }
            }
        }
        lut
    }

    /// Reference 1D LUT index: rescale `value` from [0, 255] onto
    /// [0, lut_size - 1] with round-to-nearest.
    fn lut_1d_index(value: u8, lut_size: usize) -> usize {
        let lut_max = lut_size - 1;
        ((usize::from(value) * lut_max + 127) / 255).min(lut_max)
    }

    /// Reference trilinear interpolation over an `n`×`n`×`n` RGB LUT, applied
    /// to tightly packed RGB `input`. Only complete 3-byte pixels are
    /// processed; any remaining output bytes stay zero.
    fn apply_3d_reference(lut: &[f32], n: usize, input: &[u8]) -> Vec<u8> {
        let last = n - 1;
        let nf = last as f32;
        let axis = |v: u8| {
            let pos = f32::from(v) / 255.0 * nf;
            let lo = (pos.floor() as usize).min(last);
            (lo, (lo + 1).min(last), pos - lo as f32)
        };
        let at = |ri: usize, gi: usize, bi: usize, ch: usize| -> f32 {
            lut[(ri * n * n + gi * n + bi) * 3 + ch]
        };
        let lerp = |a: f32, b: f32, t: f32| a * (1.0 - t) + b * t;

        let mut output = vec![0u8; input.len()];
        for (src, dst) in input.chunks_exact(3).zip(output.chunks_exact_mut(3)) {
            let (r0, r1, fr) = axis(src[0]);
            let (g0, g1, fg) = axis(src[1]);
            let (b0, b1, fb) = axis(src[2]);
            for (ch, out) in dst.iter_mut().enumerate() {
                let c00 = lerp(at(r0, g0, b0, ch), at(r1, g0, b0, ch), fr);
                let c01 = lerp(at(r0, g0, b1, ch), at(r1, g0, b1, ch), fr);
                let c10 = lerp(at(r0, g1, b0, ch), at(r1, g1, b0, ch), fr);
                let c11 = lerp(at(r0, g1, b1, ch), at(r1, g1, b1, ch), fr);
                let c0 = lerp(c00, c10, fg);
                let c1 = lerp(c01, c11, fg);
                let val = lerp(c0, c1, fb);
                *out = (val.clamp(0.0, 1.0) * 255.0).round() as u8;
            }
        }
        output
    }

    #[test]
    fn test_apply_1d_identity() {
        // An identity LUT should reproduce the input pixel values.
        let lut_size = 256usize;
        let channels = 3usize;
        let lut = identity_lut_1d(lut_size, channels);
        let kernel = LutKernel::new(lut_size);
        assert_eq!(kernel.lut_size(), lut_size);

        let input: Vec<u8> = vec![0, 128, 255, 64, 192, 10];
        let mut output = vec![0u8; input.len()];
        for (src, dst) in input
            .chunks_exact(channels)
            .zip(output.chunks_exact_mut(channels))
        {
            for (c, (&v, out)) in src.iter().zip(dst.iter_mut()).enumerate() {
                *out = lut[c * kernel.lut_size() + lut_1d_index(v, lut_size)];
            }
        }

        // Each output byte should be very close to the corresponding input byte.
        for (i, (&inp, &out)) in input.iter().zip(output.iter()).enumerate() {
            let diff = inp as i32 - out as i32;
            assert!(diff.abs() <= 1, "pixel {i}: input={inp}, output={out}");
        }
    }

    #[test]
    fn test_apply_1d_invert() {
        // An inversion LUT: out = 255 - in.
        let lut_size = 256usize;
        let lut: Vec<u8> = (0..lut_size).map(|i| (255 - i) as u8).collect();
        let input: Vec<u8> = vec![0, 64, 128, 192, 255];

        let output: Vec<u8> = input
            .iter()
            .map(|&v| lut[lut_1d_index(v, lut_size)])
            .collect();

        // With a 256-entry LUT the index equals the input value exactly.
        assert_eq!(output, vec![255, 191, 127, 63, 0]);
    }

    #[test]
    fn test_apply_3d_identity() {
        // An identity 3D LUT should reproduce the input pixel values (within rounding).
        let n = 17usize; // common LUT size
        let lut = identity_lut_3d(n);
        let input: Vec<u8> = vec![0, 0, 0, 128, 64, 192, 255, 255, 255];

        let output = apply_3d_reference(&lut, n, &input);

        // Each output should be within ±2 of the input (rounding in LUT grid).
        for (i, (&inp, &out)) in input.iter().zip(output.iter()).enumerate() {
            let diff = inp as i32 - out as i32;
            assert!(
                diff.abs() <= 2,
                "channel byte {i}: input={inp}, output={out}"
            );
        }
    }

    #[test]
    fn test_apply_3d_black_white() {
        // Black (0,0,0) and white (255,255,255) corners should map exactly.
        let n = 2usize; // minimal grid
        let lut = identity_lut_3d(n);
        let input: Vec<u8> = vec![0, 0, 0, 255, 255, 255];

        let output = apply_3d_reference(&lut, n, &input);

        // Black should remain black.
        assert_eq!(&output[0..3], &[0u8, 0, 0]);
        // White should remain white.
        assert_eq!(&output[3..6], &[255u8, 255, 255]);
    }
}