1use crate::{GpuDevice, Result};
4
/// Color spaces understood by the color-conversion kernels.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
// Variant names keep the conventional spelling of each standard (e.g. `YUV_BT709`),
// which is not UpperCamelCase.
#[allow(non_camel_case_types)]
pub enum ColorSpace {
    /// RGB.
    RGB,
    /// YUV using the BT.601 standard.
    YUV_BT601,
    /// YUV using the BT.709 standard.
    YUV_BT709,
    /// YUV using the BT.2020 standard.
    YUV_BT2020,
    /// Hue / saturation / value.
    HSV,
    /// Hue / saturation / lightness.
    HSL,
    /// CIE Lab.
    Lab,
    /// Linear-light RGB.
    LinearRGB,
    /// sRGB.
    SRGB,
}

impl ColorSpace {
    /// Returns `true` for the three YUV variants (BT.601, BT.709, BT.2020).
    #[must_use]
    pub fn is_yuv(self) -> bool {
        match self {
            Self::YUV_BT601 | Self::YUV_BT709 | Self::YUV_BT2020 => true,
            Self::RGB | Self::HSV | Self::HSL | Self::Lab | Self::LinearRGB | Self::SRGB => false,
        }
    }

    /// Returns `true` for the RGB-family variants: `RGB`, `LinearRGB` and `SRGB`.
    #[must_use]
    pub fn is_rgb(self) -> bool {
        match self {
            Self::RGB | Self::LinearRGB | Self::SRGB => true,
            Self::YUV_BT601
            | Self::YUV_BT709
            | Self::YUV_BT2020
            | Self::HSV
            | Self::HSL
            | Self::Lab => false,
        }
    }

    /// Human-readable name of the color space.
    #[must_use]
    pub fn name(self) -> &'static str {
        match self {
            Self::YUV_BT601 => "YUV (BT.601)",
            Self::YUV_BT709 => "YUV (BT.709)",
            Self::YUV_BT2020 => "YUV (BT.2020)",
            Self::RGB => "RGB",
            Self::LinearRGB => "Linear RGB",
            Self::SRGB => "sRGB",
            Self::HSV => "HSV",
            Self::HSL => "HSL",
            Self::Lab => "CIE Lab",
        }
    }
}
58
59impl From<ColorSpace> for crate::ops::ColorSpace {
60 fn from(space: ColorSpace) -> Self {
61 match space {
62 ColorSpace::YUV_BT601 | ColorSpace::RGB => Self::BT601,
63 ColorSpace::YUV_BT709 => Self::BT709,
64 ColorSpace::YUV_BT2020 => Self::BT2020,
65 _ => Self::BT601, }
67 }
68}
69
/// The conversions a [`ColorConversionKernel`] can perform.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ColorConversion {
    /// RGB to YUV; the kernel's color space selects the YUV standard.
    RGBtoYUV,
    /// YUV to RGB; the kernel's color space selects the YUV standard.
    YUVtoRGB,
    /// RGB to HSV.
    RGBtoHSV,
    /// HSV to RGB.
    HSVtoRGB,
    /// RGB to CIE Lab.
    RGBtoLab,
    /// CIE Lab to RGB.
    LabtoRGB,
    /// sRGB to linear RGB.
    SRGBtoLinear,
    /// Linear RGB to sRGB.
    LinearToSRGB,
}
90
91pub struct ColorConversionKernel {
93 conversion: ColorConversion,
94 color_space: ColorSpace,
95}
96
97impl ColorConversionKernel {
98 #[must_use]
100 pub fn new(conversion: ColorConversion, color_space: ColorSpace) -> Self {
101 Self {
102 conversion,
103 color_space,
104 }
105 }
106
107 #[must_use]
109 pub fn rgb_to_yuv(color_space: ColorSpace) -> Self {
110 Self::new(ColorConversion::RGBtoYUV, color_space)
111 }
112
113 #[must_use]
115 pub fn yuv_to_rgb(color_space: ColorSpace) -> Self {
116 Self::new(ColorConversion::YUVtoRGB, color_space)
117 }
118
119 pub fn execute(
133 &self,
134 device: &GpuDevice,
135 input: &[u8],
136 output: &mut [u8],
137 width: u32,
138 height: u32,
139 ) -> Result<()> {
140 match self.conversion {
141 ColorConversion::RGBtoYUV => crate::ops::ColorSpaceConversion::rgb_to_yuv(
142 device,
143 input,
144 output,
145 width,
146 height,
147 self.color_space.into(),
148 ),
149 ColorConversion::YUVtoRGB => crate::ops::ColorSpaceConversion::yuv_to_rgb(
150 device,
151 input,
152 output,
153 width,
154 height,
155 self.color_space.into(),
156 ),
157 ColorConversion::RGBtoHSV => {
158 let result = crate::ops::ColorSpaceConversion::rgb_to_hsv(input, width, height);
159 let copy_len = result.len().min(output.len());
160 output[..copy_len].copy_from_slice(&result[..copy_len]);
161 Ok(())
162 }
163 ColorConversion::HSVtoRGB => {
164 let result = crate::ops::ColorSpaceConversion::hsv_to_rgb(input, width, height);
165 let copy_len = result.len().min(output.len());
166 output[..copy_len].copy_from_slice(&result[..copy_len]);
167 Ok(())
168 }
169 ColorConversion::RGBtoLab => {
170 let result = crate::ops::ColorSpaceConversion::rgb_to_lab(input, width, height);
171 let copy_len = result.len().min(output.len());
172 output[..copy_len].copy_from_slice(&result[..copy_len]);
173 Ok(())
174 }
175 ColorConversion::LabtoRGB => {
176 let result = crate::ops::ColorSpaceConversion::lab_to_rgb(input, width, height);
177 let copy_len = result.len().min(output.len());
178 output[..copy_len].copy_from_slice(&result[..copy_len]);
179 Ok(())
180 }
181 ColorConversion::SRGBtoLinear => {
182 let result = crate::ops::ColorSpaceConversion::srgb_to_linear(input, width, height);
183 let copy_len = result.len().min(output.len());
184 output[..copy_len].copy_from_slice(&result[..copy_len]);
185 Ok(())
186 }
187 ColorConversion::LinearToSRGB => {
188 let result = crate::ops::ColorSpaceConversion::linear_to_srgb(input, width, height);
189 let copy_len = result.len().min(output.len());
190 output[..copy_len].copy_from_slice(&result[..copy_len]);
191 Ok(())
192 }
193 }
194 }
195
196 #[must_use]
198 pub fn conversion(&self) -> ColorConversion {
199 self.conversion
200 }
201
202 #[must_use]
204 pub fn color_space(&self) -> ColorSpace {
205 self.color_space
206 }
207
208 #[must_use]
210 pub fn output_size(width: u32, height: u32, channels: u32) -> usize {
211 (width * height * channels) as usize
212 }
213
214 #[must_use]
216 pub fn estimate_flops(width: u32, height: u32, conversion: ColorConversion) -> u64 {
217 let pixels = u64::from(width) * u64::from(height);
218
219 match conversion {
220 ColorConversion::RGBtoYUV | ColorConversion::YUVtoRGB => {
221 pixels * 15
223 }
224 ColorConversion::RGBtoHSV | ColorConversion::HSVtoRGB => {
225 pixels * 20
227 }
228 ColorConversion::RGBtoLab | ColorConversion::LabtoRGB => {
229 pixels * 50
231 }
232 ColorConversion::SRGBtoLinear | ColorConversion::LinearToSRGB => {
233 pixels * 3 * 5
235 }
236 }
237 }
238}
239
/// Lookup-table kernel applying 1D per-channel or 3D RGB LUTs to 8-bit pixel data.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct LutKernel {
    /// Entries per channel for 1D LUTs, or lattice points per axis for 3D LUTs.
    lut_size: usize,
}
244
245impl LutKernel {
246 #[must_use]
252 pub fn new(lut_size: usize) -> Self {
253 Self { lut_size }
254 }
255
256 #[must_use]
258 pub fn lut_size(&self) -> usize {
259 self.lut_size
260 }
261
262 #[allow(clippy::too_many_arguments)]
285 pub fn apply_1d(
286 &self,
287 _device: &GpuDevice,
288 input: &[u8],
289 output: &mut [u8],
290 lut: &[u8],
291 _width: u32,
292 _height: u32,
293 ) -> Result<()> {
294 if self.lut_size == 0 || lut.is_empty() {
295 return Err(crate::GpuError::NotSupported(
296 "1D LUT size must be non-zero".to_string(),
297 ));
298 }
299 let channels = lut.len() / self.lut_size;
300 if channels == 0 {
301 return Err(crate::GpuError::NotSupported(
302 "1D LUT must cover at least one channel".to_string(),
303 ));
304 }
305 let lut_max = self.lut_size - 1;
306 let full_pixels = input.len() / channels;
309 for px in 0..full_pixels {
310 let base = px * channels;
311 for c in 0..channels {
312 let pixel_val = input[base + c] as usize;
313 let lut_idx = (pixel_val * lut_max + 127) / 255; let lut_idx = lut_idx.min(lut_max);
316 output[base + c] = lut[c * self.lut_size + lut_idx];
317 }
318 }
319 let tail_start = full_pixels * channels;
321 output[tail_start..input.len()].copy_from_slice(&input[tail_start..]);
322 Ok(())
323 }
324
325 #[allow(clippy::too_many_arguments)]
347 pub fn apply_3d(
348 &self,
349 _device: &GpuDevice,
350 input: &[u8],
351 output: &mut [u8],
352 lut: &[f32],
353 _width: u32,
354 _height: u32,
355 ) -> Result<()> {
356 let n = self.lut_size;
357 if n == 0 {
358 return Err(crate::GpuError::NotSupported(
359 "3D LUT size must be non-zero".to_string(),
360 ));
361 }
362 let expected_lut = n * n * n * 3;
363 if lut.len() < expected_lut {
364 return Err(crate::GpuError::NotSupported(format!(
365 "3D LUT too small: expected {expected_lut} entries, got {}",
366 lut.len()
367 )));
368 }
369
370 let pixel_stride = 3usize;
372 let full_pixels = input.len() / pixel_stride;
373
374 for px in 0..full_pixels {
375 let base = px * pixel_stride;
376 let r = f32::from(input[base]) / 255.0;
378 let g = f32::from(input[base + 1]) / 255.0;
379 let b = f32::from(input[base + 2]) / 255.0;
380
381 let nf = (n - 1) as f32;
383 let rx = r * nf;
384 let gy = g * nf;
385 let bz = b * nf;
386
387 let r0 = (rx.floor() as usize).min(n - 1);
389 let g0 = (gy.floor() as usize).min(n - 1);
390 let b0 = (bz.floor() as usize).min(n - 1);
391 let r1 = (r0 + 1).min(n - 1);
392 let g1 = (g0 + 1).min(n - 1);
393 let b1 = (b0 + 1).min(n - 1);
394
395 let fr = rx - r0 as f32;
397 let fg = gy - g0 as f32;
398 let fb = bz - b0 as f32;
399
400 let lut_val = |ri: usize, gi: usize, bi: usize, ch: usize| -> f32 {
402 lut[(ri * n * n + gi * n + bi) * 3 + ch]
403 };
404
405 for ch in 0..3 {
406 let c000 = lut_val(r0, g0, b0, ch);
408 let c100 = lut_val(r1, g0, b0, ch);
409 let c010 = lut_val(r0, g1, b0, ch);
410 let c110 = lut_val(r1, g1, b0, ch);
411 let c001 = lut_val(r0, g0, b1, ch);
412 let c101 = lut_val(r1, g0, b1, ch);
413 let c011 = lut_val(r0, g1, b1, ch);
414 let c111 = lut_val(r1, g1, b1, ch);
415
416 let c00 = c000 * (1.0 - fr) + c100 * fr;
417 let c01 = c001 * (1.0 - fr) + c101 * fr;
418 let c10 = c010 * (1.0 - fr) + c110 * fr;
419 let c11 = c011 * (1.0 - fr) + c111 * fr;
420
421 let c0 = c00 * (1.0 - fg) + c10 * fg;
422 let c1 = c01 * (1.0 - fg) + c11 * fg;
423
424 let val = c0 * (1.0 - fb) + c1 * fb;
425 output[base + ch] = (val.clamp(0.0, 1.0) * 255.0).round() as u8;
426 }
427 }
428
429 let tail_start = full_pixels * pixel_stride;
431 output[tail_start..input.len()].copy_from_slice(&input[tail_start..]);
432 Ok(())
433 }
434}
435
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE(review): `LutKernel::apply_1d` / `apply_3d` require a `GpuDevice`, so the LUT
    // tests below check reference re-implementations of the kernel math rather than the
    // kernel itself. Keep `reference_apply_1d` / `reference_apply_3d` in sync with
    // `LutKernel`, or switch these tests to the real entry points once a test device can
    // be constructed.

    #[test]
    fn test_color_space_properties() {
        assert!(ColorSpace::YUV_BT601.is_yuv());
        assert!(ColorSpace::YUV_BT709.is_yuv());
        assert!(ColorSpace::YUV_BT2020.is_yuv());
        assert!(!ColorSpace::RGB.is_yuv());

        assert!(ColorSpace::RGB.is_rgb());
        assert!(ColorSpace::LinearRGB.is_rgb());
        assert!(ColorSpace::SRGB.is_rgb());
        assert!(!ColorSpace::YUV_BT601.is_rgb());
    }

    #[test]
    fn test_color_space_names() {
        assert_eq!(ColorSpace::RGB.name(), "RGB");
        assert_eq!(ColorSpace::YUV_BT2020.name(), "YUV (BT.2020)");
        assert_eq!(ColorSpace::Lab.name(), "CIE Lab");
        assert_eq!(ColorSpace::SRGB.name(), "sRGB");
    }

    #[test]
    fn test_color_conversion_kernel() {
        let kernel = ColorConversionKernel::rgb_to_yuv(ColorSpace::YUV_BT709);
        assert_eq!(kernel.conversion(), ColorConversion::RGBtoYUV);
        assert_eq!(kernel.color_space(), ColorSpace::YUV_BT709);
    }

    #[test]
    fn test_output_size() {
        assert_eq!(ColorConversionKernel::output_size(4, 2, 3), 24);
        assert_eq!(ColorConversionKernel::output_size(0, 1080, 3), 0);
    }

    #[test]
    fn test_flops_estimation() {
        let flops = ColorConversionKernel::estimate_flops(1920, 1080, ColorConversion::RGBtoYUV);
        assert!(flops > 0);

        let flops_lab =
            ColorConversionKernel::estimate_flops(1920, 1080, ColorConversion::RGBtoLab);
        assert!(flops_lab > flops);

        assert_eq!(
            ColorConversionKernel::estimate_flops(2, 2, ColorConversion::SRGBtoLinear),
            60
        );
    }

    /// Identity 1D LUT: `channels` tables mapping index `i` to `i * 255 / (lut_size - 1)`.
    fn identity_lut_1d(lut_size: usize, channels: usize) -> Vec<u8> {
        let mut lut = vec![0u8; lut_size * channels];
        for c in 0..channels {
            for i in 0..lut_size {
                lut[c * lut_size + i] = ((i * 255) / (lut_size - 1)) as u8;
            }
        }
        lut
    }

    /// Identity 3D LUT of `n` points per axis, laid out as `LutKernel::apply_3d` reads it.
    fn identity_lut_3d(n: usize) -> Vec<f32> {
        let mut lut = vec![0.0f32; n * n * n * 3];
        for ri in 0..n {
            for gi in 0..n {
                for bi in 0..n {
                    let base = (ri * n * n + gi * n + bi) * 3;
                    lut[base] = ri as f32 / (n - 1) as f32;
                    lut[base + 1] = gi as f32 / (n - 1) as f32;
                    lut[base + 2] = bi as f32 / (n - 1) as f32;
                }
            }
        }
        lut
    }

    /// Reference copy of the `LutKernel::apply_1d` math; returns a fresh output buffer.
    fn reference_apply_1d(input: &[u8], lut: &[u8], lut_size: usize, channels: usize) -> Vec<u8> {
        let mut output = vec![0u8; input.len()];
        let lut_max = lut_size - 1;
        let full_pixels = input.len() / channels;
        for px in 0..full_pixels {
            let base = px * channels;
            for c in 0..channels {
                let pixel_val = usize::from(input[base + c]);
                let lut_idx = ((pixel_val * lut_max + 127) / 255).min(lut_max);
                output[base + c] = lut[c * lut_size + lut_idx];
            }
        }
        let tail_start = full_pixels * channels;
        output[tail_start..].copy_from_slice(&input[tail_start..]);
        output
    }

    /// Reference copy of the `LutKernel::apply_3d` trilinear math; returns a fresh buffer.
    fn reference_apply_3d(input: &[u8], lut: &[f32], n: usize) -> Vec<u8> {
        let mut output = vec![0u8; input.len()];
        let nf = (n - 1) as f32;
        let pixel_stride = 3usize;
        let full_pixels = input.len() / pixel_stride;

        for px in 0..full_pixels {
            let base = px * pixel_stride;
            let r = f32::from(input[base]) / 255.0;
            let g = f32::from(input[base + 1]) / 255.0;
            let b = f32::from(input[base + 2]) / 255.0;

            let rx = r * nf;
            let gy = g * nf;
            let bz = b * nf;

            let r0 = (rx.floor() as usize).min(n - 1);
            let g0 = (gy.floor() as usize).min(n - 1);
            let b0 = (bz.floor() as usize).min(n - 1);
            let r1 = (r0 + 1).min(n - 1);
            let g1 = (g0 + 1).min(n - 1);
            let b1 = (b0 + 1).min(n - 1);
            let fr = rx - r0 as f32;
            let fg = gy - g0 as f32;
            let fb = bz - b0 as f32;

            for ch in 0..3 {
                let lv = |ri: usize, gi: usize, bi: usize| -> f32 {
                    lut[(ri * n * n + gi * n + bi) * 3 + ch]
                };
                let c000 = lv(r0, g0, b0);
                let c100 = lv(r1, g0, b0);
                let c010 = lv(r0, g1, b0);
                let c110 = lv(r1, g1, b0);
                let c001 = lv(r0, g0, b1);
                let c101 = lv(r1, g0, b1);
                let c011 = lv(r0, g1, b1);
                let c111 = lv(r1, g1, b1);

                let c00 = c000 * (1.0 - fr) + c100 * fr;
                let c01 = c001 * (1.0 - fr) + c101 * fr;
                let c10 = c010 * (1.0 - fr) + c110 * fr;
                let c11 = c011 * (1.0 - fr) + c111 * fr;
                let c0 = c00 * (1.0 - fg) + c10 * fg;
                let c1 = c01 * (1.0 - fg) + c11 * fg;
                let val = c0 * (1.0 - fb) + c1 * fb;
                output[base + ch] = (val.clamp(0.0, 1.0) * 255.0).round() as u8;
            }
        }

        let tail_start = full_pixels * pixel_stride;
        output[tail_start..].copy_from_slice(&input[tail_start..]);
        output
    }

    #[test]
    fn test_apply_1d_identity() {
        let lut_size = 256usize;
        let channels = 3usize;
        let lut = identity_lut_1d(lut_size, channels);
        let input: Vec<u8> = vec![0, 128, 255, 64, 192, 10];

        let kernel = LutKernel::new(lut_size);
        assert_eq!(kernel.lut_size(), lut_size);
        let output = reference_apply_1d(&input, &lut, kernel.lut_size(), channels);

        for (i, (&inp, &out)) in input.iter().zip(output.iter()).enumerate() {
            let diff = i32::from(inp) - i32::from(out);
            assert!(diff.abs() <= 1, "pixel {i}: input={inp}, output={out}");
        }
    }

    #[test]
    fn test_apply_1d_invert() {
        let lut_size = 256usize;
        let lut: Vec<u8> = (0..lut_size).map(|i| (255 - i) as u8).collect();
        let input: Vec<u8> = vec![0, 64, 128, 192, 255];

        let output = reference_apply_1d(&input, &lut, lut_size, 1);

        assert_eq!(output[0], 255);
        assert_eq!(output[4], 0);
    }

    #[test]
    fn test_apply_3d_identity() {
        let n = 17usize;
        let lut = identity_lut_3d(n);
        let input: Vec<u8> = vec![0, 0, 0, 128, 64, 192, 255, 255, 255];

        let output = reference_apply_3d(&input, &lut, n);

        for (i, (&inp, &out)) in input.iter().zip(output.iter()).enumerate() {
            let diff = i32::from(inp) - i32::from(out);
            assert!(
                diff.abs() <= 2,
                "channel byte {i}: input={inp}, output={out}"
            );
        }
    }

    #[test]
    fn test_apply_3d_black_white() {
        let n = 2usize;
        let lut = identity_lut_3d(n);
        let input: Vec<u8> = vec![0, 0, 0, 255, 255, 255];

        let output = reference_apply_3d(&input, &lut, n);

        assert_eq!(&output[0..3], &[0u8, 0, 0]);
        assert_eq!(&output[3..6], &[255u8, 255, 255]);
    }
}