1use crate::{GpuDevice, Result};
4
/// Color spaces that the color kernels in this file can be parameterised with.
///
/// The human-readable label of every variant is available through
/// [`ColorSpace::name`].
#[allow(non_camel_case_types)] // the `YUV_*` variants keep their underscore spelling
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ColorSpace {
    /// Labelled "RGB".
    RGB,
    /// Labelled "YUV (BT.601)".
    YUV_BT601,
    /// Labelled "YUV (BT.709)".
    YUV_BT709,
    /// Labelled "YUV (BT.2020)".
    YUV_BT2020,
    /// Labelled "HSV".
    HSV,
    /// Labelled "HSL".
    HSL,
    /// Labelled "CIE Lab".
    Lab,
    /// Labelled "Linear RGB".
    LinearRGB,
    /// Labelled "sRGB".
    SRGB,
}

impl ColorSpace {
    /// Returns `true` for the three `YUV_*` variants and `false` otherwise.
    #[must_use]
    pub fn is_yuv(self) -> bool {
        match self {
            Self::YUV_BT601 | Self::YUV_BT709 | Self::YUV_BT2020 => true,
            Self::RGB | Self::HSV | Self::HSL | Self::Lab | Self::LinearRGB | Self::SRGB => false,
        }
    }

    /// Returns `true` for `RGB`, `LinearRGB` and `SRGB`, and `false` otherwise.
    #[must_use]
    pub fn is_rgb(self) -> bool {
        match self {
            Self::RGB | Self::LinearRGB | Self::SRGB => true,
            Self::YUV_BT601
            | Self::YUV_BT709
            | Self::YUV_BT2020
            | Self::HSV
            | Self::HSL
            | Self::Lab => false,
        }
    }

    /// Returns a static, human-readable label for the color space.
    #[must_use]
    pub fn name(self) -> &'static str {
        match self {
            // RGB family
            Self::RGB => "RGB",
            Self::LinearRGB => "Linear RGB",
            Self::SRGB => "sRGB",
            // YUV family
            Self::YUV_BT601 => "YUV (BT.601)",
            Self::YUV_BT709 => "YUV (BT.709)",
            Self::YUV_BT2020 => "YUV (BT.2020)",
            // Everything else
            Self::HSV => "HSV",
            Self::HSL => "HSL",
            Self::Lab => "CIE Lab",
        }
    }
}
58
59impl From<ColorSpace> for crate::ops::ColorSpace {
60 fn from(space: ColorSpace) -> Self {
61 match space {
62 ColorSpace::YUV_BT601 | ColorSpace::RGB => Self::BT601,
63 ColorSpace::YUV_BT709 => Self::BT709,
64 ColorSpace::YUV_BT2020 => Self::BT2020,
65 _ => Self::BT601, }
67 }
68}
69
/// Direction of a color conversion.
///
/// The `Debug` form of a variant is what appears in the "not yet implemented"
/// error produced by `ColorConversionKernel::execute`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ColorConversion {
    /// RGB to YUV; dispatched by `ColorConversionKernel::execute`.
    RGBtoYUV,
    /// YUV to RGB; dispatched by `ColorConversionKernel::execute`.
    YUVtoRGB,
    /// RGB to HSV; `execute` currently reports it as not implemented.
    RGBtoHSV,
    /// HSV to RGB; `execute` currently reports it as not implemented.
    HSVtoRGB,
    /// RGB to Lab; `execute` currently reports it as not implemented.
    RGBtoLab,
    /// Lab to RGB; `execute` currently reports it as not implemented.
    LabtoRGB,
    /// sRGB to linear; `execute` currently reports it as not implemented.
    SRGBtoLinear,
    /// Linear to sRGB; `execute` currently reports it as not implemented.
    LinearToSRGB,
}
90
91pub struct ColorConversionKernel {
93 conversion: ColorConversion,
94 color_space: ColorSpace,
95}
96
97impl ColorConversionKernel {
98 #[must_use]
100 pub fn new(conversion: ColorConversion, color_space: ColorSpace) -> Self {
101 Self {
102 conversion,
103 color_space,
104 }
105 }
106
107 #[must_use]
109 pub fn rgb_to_yuv(color_space: ColorSpace) -> Self {
110 Self::new(ColorConversion::RGBtoYUV, color_space)
111 }
112
113 #[must_use]
115 pub fn yuv_to_rgb(color_space: ColorSpace) -> Self {
116 Self::new(ColorConversion::YUVtoRGB, color_space)
117 }
118
119 pub fn execute(
133 &self,
134 device: &GpuDevice,
135 input: &[u8],
136 output: &mut [u8],
137 width: u32,
138 height: u32,
139 ) -> Result<()> {
140 match self.conversion {
141 ColorConversion::RGBtoYUV => crate::ops::ColorSpaceConversion::rgb_to_yuv(
142 device,
143 input,
144 output,
145 width,
146 height,
147 self.color_space.into(),
148 ),
149 ColorConversion::YUVtoRGB => crate::ops::ColorSpaceConversion::yuv_to_rgb(
150 device,
151 input,
152 output,
153 width,
154 height,
155 self.color_space.into(),
156 ),
157 _ => {
158 Err(crate::GpuError::NotSupported(format!(
161 "Color conversion {:?} not yet implemented",
162 self.conversion
163 )))
164 }
165 }
166 }
167
168 #[must_use]
170 pub fn conversion(&self) -> ColorConversion {
171 self.conversion
172 }
173
174 #[must_use]
176 pub fn color_space(&self) -> ColorSpace {
177 self.color_space
178 }
179
180 #[must_use]
182 pub fn output_size(width: u32, height: u32, channels: u32) -> usize {
183 (width * height * channels) as usize
184 }
185
186 #[must_use]
188 pub fn estimate_flops(width: u32, height: u32, conversion: ColorConversion) -> u64 {
189 let pixels = u64::from(width) * u64::from(height);
190
191 match conversion {
192 ColorConversion::RGBtoYUV | ColorConversion::YUVtoRGB => {
193 pixels * 15
195 }
196 ColorConversion::RGBtoHSV | ColorConversion::HSVtoRGB => {
197 pixels * 20
199 }
200 ColorConversion::RGBtoLab | ColorConversion::LabtoRGB => {
201 pixels * 50
203 }
204 ColorConversion::SRGBtoLinear | ColorConversion::LinearToSRGB => {
205 pixels * 3 * 5
207 }
208 }
209 }
210}
211
/// Lookup-table kernel configured with a table size.
///
/// `lut_size` is the number of entries per channel for `apply_1d` and the
/// number of grid points per axis for `apply_3d`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct LutKernel {
    /// Entries per channel (1D) or grid points per axis (3D).
    lut_size: usize,
}
216
217impl LutKernel {
218 #[must_use]
224 pub fn new(lut_size: usize) -> Self {
225 Self { lut_size }
226 }
227
228 #[must_use]
230 pub fn lut_size(&self) -> usize {
231 self.lut_size
232 }
233
234 #[allow(clippy::too_many_arguments)]
257 pub fn apply_1d(
258 &self,
259 _device: &GpuDevice,
260 input: &[u8],
261 output: &mut [u8],
262 lut: &[u8],
263 _width: u32,
264 _height: u32,
265 ) -> Result<()> {
266 if self.lut_size == 0 || lut.is_empty() {
267 return Err(crate::GpuError::NotSupported(
268 "1D LUT size must be non-zero".to_string(),
269 ));
270 }
271 let channels = lut.len() / self.lut_size;
272 if channels == 0 {
273 return Err(crate::GpuError::NotSupported(
274 "1D LUT must cover at least one channel".to_string(),
275 ));
276 }
277 let lut_max = self.lut_size - 1;
278 let full_pixels = input.len() / channels;
281 for px in 0..full_pixels {
282 let base = px * channels;
283 for c in 0..channels {
284 let pixel_val = input[base + c] as usize;
285 let lut_idx = (pixel_val * lut_max + 127) / 255; let lut_idx = lut_idx.min(lut_max);
288 output[base + c] = lut[c * self.lut_size + lut_idx];
289 }
290 }
291 let tail_start = full_pixels * channels;
293 output[tail_start..input.len()].copy_from_slice(&input[tail_start..]);
294 Ok(())
295 }
296
297 #[allow(clippy::too_many_arguments)]
319 pub fn apply_3d(
320 &self,
321 _device: &GpuDevice,
322 input: &[u8],
323 output: &mut [u8],
324 lut: &[f32],
325 _width: u32,
326 _height: u32,
327 ) -> Result<()> {
328 let n = self.lut_size;
329 if n == 0 {
330 return Err(crate::GpuError::NotSupported(
331 "3D LUT size must be non-zero".to_string(),
332 ));
333 }
334 let expected_lut = n * n * n * 3;
335 if lut.len() < expected_lut {
336 return Err(crate::GpuError::NotSupported(format!(
337 "3D LUT too small: expected {expected_lut} entries, got {}",
338 lut.len()
339 )));
340 }
341
342 let pixel_stride = 3usize;
344 let full_pixels = input.len() / pixel_stride;
345
346 for px in 0..full_pixels {
347 let base = px * pixel_stride;
348 let r = f32::from(input[base]) / 255.0;
350 let g = f32::from(input[base + 1]) / 255.0;
351 let b = f32::from(input[base + 2]) / 255.0;
352
353 let nf = (n - 1) as f32;
355 let rx = r * nf;
356 let gy = g * nf;
357 let bz = b * nf;
358
359 let r0 = (rx.floor() as usize).min(n - 1);
361 let g0 = (gy.floor() as usize).min(n - 1);
362 let b0 = (bz.floor() as usize).min(n - 1);
363 let r1 = (r0 + 1).min(n - 1);
364 let g1 = (g0 + 1).min(n - 1);
365 let b1 = (b0 + 1).min(n - 1);
366
367 let fr = rx - r0 as f32;
369 let fg = gy - g0 as f32;
370 let fb = bz - b0 as f32;
371
372 let lut_val = |ri: usize, gi: usize, bi: usize, ch: usize| -> f32 {
374 lut[(ri * n * n + gi * n + bi) * 3 + ch]
375 };
376
377 for ch in 0..3 {
378 let c000 = lut_val(r0, g0, b0, ch);
380 let c100 = lut_val(r1, g0, b0, ch);
381 let c010 = lut_val(r0, g1, b0, ch);
382 let c110 = lut_val(r1, g1, b0, ch);
383 let c001 = lut_val(r0, g0, b1, ch);
384 let c101 = lut_val(r1, g0, b1, ch);
385 let c011 = lut_val(r0, g1, b1, ch);
386 let c111 = lut_val(r1, g1, b1, ch);
387
388 let c00 = c000 * (1.0 - fr) + c100 * fr;
389 let c01 = c001 * (1.0 - fr) + c101 * fr;
390 let c10 = c010 * (1.0 - fr) + c110 * fr;
391 let c11 = c011 * (1.0 - fr) + c111 * fr;
392
393 let c0 = c00 * (1.0 - fg) + c10 * fg;
394 let c1 = c01 * (1.0 - fg) + c11 * fg;
395
396 let val = c0 * (1.0 - fb) + c1 * fb;
397 output[base + ch] = (val.clamp(0.0, 1.0) * 255.0).round() as u8;
398 }
399 }
400
401 let tail_start = full_pixels * pixel_stride;
403 output[tail_start..input.len()].copy_from_slice(&input[tail_start..]);
404 Ok(())
405 }
406}
407
408#[cfg(test)]
409mod tests {
410 use super::*;
411
412 #[test]
413 fn test_color_space_properties() {
414 assert!(ColorSpace::YUV_BT601.is_yuv());
415 assert!(ColorSpace::YUV_BT709.is_yuv());
416 assert!(ColorSpace::YUV_BT2020.is_yuv());
417 assert!(!ColorSpace::RGB.is_yuv());
418
419 assert!(ColorSpace::RGB.is_rgb());
420 assert!(ColorSpace::LinearRGB.is_rgb());
421 assert!(ColorSpace::SRGB.is_rgb());
422 assert!(!ColorSpace::YUV_BT601.is_rgb());
423 }
424
425 #[test]
426 fn test_color_conversion_kernel() {
427 let kernel = ColorConversionKernel::rgb_to_yuv(ColorSpace::YUV_BT709);
428 assert_eq!(kernel.conversion(), ColorConversion::RGBtoYUV);
429 assert_eq!(kernel.color_space(), ColorSpace::YUV_BT709);
430 }
431
432 #[test]
433 fn test_flops_estimation() {
434 let flops = ColorConversionKernel::estimate_flops(1920, 1080, ColorConversion::RGBtoYUV);
435 assert!(flops > 0);
436
437 let flops_lab =
438 ColorConversionKernel::estimate_flops(1920, 1080, ColorConversion::RGBtoLab);
439 assert!(flops_lab > flops); }
441
442 fn identity_lut_1d(lut_size: usize, channels: usize) -> Vec<u8> {
446 let mut lut = vec![0u8; lut_size * channels];
447 for c in 0..channels {
448 for i in 0..lut_size {
449 lut[c * lut_size + i] = ((i * 255) / (lut_size - 1)) as u8;
451 }
452 }
453 lut
454 }
455
456 fn identity_lut_3d(n: usize) -> Vec<f32> {
458 let mut lut = vec![0.0f32; n * n * n * 3];
459 for ri in 0..n {
460 for gi in 0..n {
461 for bi in 0..n {
462 let base = (ri * n * n + gi * n + bi) * 3;
463 lut[base] = ri as f32 / (n - 1) as f32;
464 lut[base + 1] = gi as f32 / (n - 1) as f32;
465 lut[base + 2] = bi as f32 / (n - 1) as f32;
466 }
467 }
468 }
469 lut
470 }
471
472 #[test]
473 fn test_apply_1d_identity() {
474 let lut_size = 256usize;
476 let channels = 3usize;
477 let lut = identity_lut_1d(lut_size, channels);
478 let input: Vec<u8> = vec![0, 128, 255, 64, 192, 10];
479 let mut output = vec![0u8; input.len()];
480
481 let kernel = LutKernel::new(lut_size);
483 let lut_max = lut_size - 1;
484 let full_pixels = input.len() / channels;
485 for px in 0..full_pixels {
486 let base = px * channels;
487 for c in 0..channels {
488 let pixel_val = input[base + c] as usize;
489 let lut_idx = ((pixel_val * lut_max + 127) / 255).min(lut_max);
490 output[base + c] = lut[c * kernel.lut_size() + lut_idx];
491 }
492 }
493
494 for (i, (&inp, &out)) in input.iter().zip(output.iter()).enumerate() {
496 let diff = inp as i32 - out as i32;
497 assert!(diff.abs() <= 1, "pixel {i}: input={inp}, output={out}");
498 }
499 }
500
501 #[test]
502 fn test_apply_1d_invert() {
503 let lut_size = 256usize;
505 let _channels = 1usize;
506 let lut: Vec<u8> = (0..lut_size).map(|i| (255 - i) as u8).collect();
507 let input: Vec<u8> = vec![0, 64, 128, 192, 255];
508 let mut output = vec![0u8; input.len()];
509
510 let lut_max = lut_size - 1;
511 for (i, &v) in input.iter().enumerate() {
512 let lut_idx = ((v as usize * lut_max + 127) / 255).min(lut_max);
513 output[i] = lut[lut_idx];
514 }
515
516 assert_eq!(output[0], 255);
517 assert_eq!(output[4], 0);
518 }
519
520 #[test]
521 fn test_apply_3d_identity() {
522 let n = 17usize; let lut = identity_lut_3d(n);
525 let input: Vec<u8> = vec![0, 0, 0, 128, 64, 192, 255, 255, 255];
526 let mut output = vec![0u8; input.len()];
527
528 let nf = (n - 1) as f32;
529 let pixel_stride = 3usize;
530 let full_pixels = input.len() / pixel_stride;
531
532 for px in 0..full_pixels {
533 let base = px * pixel_stride;
534 let r = input[base] as f32 / 255.0;
535 let g = input[base + 1] as f32 / 255.0;
536 let b = input[base + 2] as f32 / 255.0;
537
538 let rx = r * nf;
539 let gy = g * nf;
540 let bz = b * nf;
541
542 let r0 = (rx.floor() as usize).min(n - 1);
543 let g0 = (gy.floor() as usize).min(n - 1);
544 let b0 = (bz.floor() as usize).min(n - 1);
545 let r1 = (r0 + 1).min(n - 1);
546 let g1 = (g0 + 1).min(n - 1);
547 let b1 = (b0 + 1).min(n - 1);
548 let fr = rx - r0 as f32;
549 let fg = gy - g0 as f32;
550 let fb = bz - b0 as f32;
551
552 for ch in 0..3 {
553 let lv = |ri: usize, gi: usize, bi: usize| -> f32 {
554 lut[(ri * n * n + gi * n + bi) * 3 + ch]
555 };
556 let c000 = lv(r0, g0, b0);
557 let c100 = lv(r1, g0, b0);
558 let c010 = lv(r0, g1, b0);
559 let c110 = lv(r1, g1, b0);
560 let c001 = lv(r0, g0, b1);
561 let c101 = lv(r1, g0, b1);
562 let c011 = lv(r0, g1, b1);
563 let c111 = lv(r1, g1, b1);
564
565 let c00 = c000 * (1.0 - fr) + c100 * fr;
566 let c01 = c001 * (1.0 - fr) + c101 * fr;
567 let c10 = c010 * (1.0 - fr) + c110 * fr;
568 let c11 = c011 * (1.0 - fr) + c111 * fr;
569 let c0 = c00 * (1.0 - fg) + c10 * fg;
570 let c1 = c01 * (1.0 - fg) + c11 * fg;
571 let val = c0 * (1.0 - fb) + c1 * fb;
572 output[base + ch] = (val.clamp(0.0, 1.0) * 255.0).round() as u8;
573 }
574 }
575
576 for (i, (&inp, &out)) in input.iter().zip(output.iter()).enumerate() {
578 let diff = inp as i32 - out as i32;
579 assert!(
580 diff.abs() <= 2,
581 "channel byte {i}: input={inp}, output={out}"
582 );
583 }
584 }
585
586 #[test]
587 fn test_apply_3d_black_white() {
588 let n = 2usize; let lut = identity_lut_3d(n);
591 let input: Vec<u8> = vec![0, 0, 0, 255, 255, 255];
592 let mut output = vec![0u8; 6];
593
594 let nf = (n - 1) as f32;
595 for px in 0..2usize {
596 let base = px * 3;
597 let r = input[base] as f32 / 255.0;
598 let g = input[base + 1] as f32 / 255.0;
599 let b = input[base + 2] as f32 / 255.0;
600 let rx = r * nf;
601 let gy = g * nf;
602 let bz = b * nf;
603 let r0 = (rx.floor() as usize).min(n - 1);
604 let g0 = (gy.floor() as usize).min(n - 1);
605 let b0 = (bz.floor() as usize).min(n - 1);
606 let r1 = (r0 + 1).min(n - 1);
607 let g1 = (g0 + 1).min(n - 1);
608 let b1 = (b0 + 1).min(n - 1);
609 let fr = rx - r0 as f32;
610 let fg = gy - g0 as f32;
611 let fb = bz - b0 as f32;
612 for ch in 0..3 {
613 let lv = |ri: usize, gi: usize, bi: usize| -> f32 {
614 lut[(ri * n * n + gi * n + bi) * 3 + ch]
615 };
616 let c000 = lv(r0, g0, b0);
617 let c100 = lv(r1, g0, b0);
618 let c010 = lv(r0, g1, b0);
619 let c110 = lv(r1, g1, b0);
620 let c001 = lv(r0, g0, b1);
621 let c101 = lv(r1, g0, b1);
622 let c011 = lv(r0, g1, b1);
623 let c111 = lv(r1, g1, b1);
624 let c00 = c000 * (1.0 - fr) + c100 * fr;
625 let c01 = c001 * (1.0 - fr) + c101 * fr;
626 let c10 = c010 * (1.0 - fr) + c110 * fr;
627 let c11 = c011 * (1.0 - fr) + c111 * fr;
628 let c0 = c00 * (1.0 - fg) + c10 * fg;
629 let c1 = c01 * (1.0 - fg) + c11 * fg;
630 let val = c0 * (1.0 - fb) + c1 * fb;
631 output[base + ch] = (val.clamp(0.0, 1.0) * 255.0).round() as u8;
632 }
633 }
634
635 assert_eq!(&output[0..3], &[0u8, 0, 0]);
637 assert_eq!(&output[3..6], &[255u8, 255, 255]);
639 }
640}