Skip to main content

oximedia_gpu/ops/
colorspace.rs

//! Color space conversion operations (RGB ↔ YUV)
2
3use crate::{
4    shader::{BindGroupLayoutBuilder, ShaderCompiler, ShaderSource},
5    GpuDevice, Result,
6};
7use bytemuck::{Pod, Zeroable};
8use once_cell::sync::OnceCell;
9use wgpu::{BindGroup, BindGroupLayout, ComputePipeline};
10
11use super::utils;
12
/// Color space standards
///
/// Selects which standard the colorspace shader applies; the choice is passed to the
/// shader as a small integer via [`ColorSpace::to_format_id`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ColorSpace {
    /// BT.601 (SD video)
    BT601,
    /// BT.709 (HD video)
    BT709,
    /// BT.2020 (UHD video)
    BT2020,
}

impl ColorSpace {
    /// Numeric identifier for this standard, written into the shader's `format` parameter.
    ///
    /// Mapping: BT.601 → 0, BT.709 → 1, BT.2020 → 2.
    ///
    /// NOTE(review): these values must agree with how the embedded colorspace shader
    /// interprets `format` (the shader is not visible here) — confirm before renumbering.
    const fn to_format_id(self) -> u32 {
        match self {
            Self::BT601 => 0,
            Self::BT709 => 1,
            Self::BT2020 => 2,
        }
    }
}
33
34#[repr(C)]
35#[derive(Copy, Clone, Pod, Zeroable)]
36struct ConversionParams {
37    width: u32,
38    height: u32,
39    stride: u32,
40    format: u32,
41}
42
43/// Color space conversion operations
44pub struct ColorSpaceConversion;
45
46impl ColorSpaceConversion {
47    /// Convert RGB to YUV
48    ///
49    /// # Arguments
50    ///
51    /// * `device` - GPU device
52    /// * `input` - Input RGB buffer (packed RGBA format)
53    /// * `output` - Output YUV buffer (packed YUVA format)
54    /// * `width` - Image width
55    /// * `height` - Image height
56    /// * `color_space` - Color space standard (BT.601, BT.709, BT.2020)
57    ///
58    /// # Errors
59    ///
60    /// Returns an error if buffer sizes are invalid or if the GPU operation fails.
61    #[allow(clippy::too_many_arguments)]
62    pub fn rgb_to_yuv(
63        device: &GpuDevice,
64        input: &[u8],
65        output: &mut [u8],
66        width: u32,
67        height: u32,
68        color_space: ColorSpace,
69    ) -> Result<()> {
70        utils::validate_dimensions(width, height)?;
71        utils::validate_buffer_size(input, width, height, 4)?;
72        utils::validate_buffer_size(output, width, height, 4)?;
73
74        let pipeline = Self::get_rgb_to_yuv_pipeline(device)?;
75        let layout = Self::get_bind_group_layout(device)?;
76
77        Self::execute_conversion(
78            device,
79            pipeline,
80            layout,
81            input,
82            output,
83            width,
84            height,
85            color_space,
86        )
87    }
88
89    /// Convert YUV to RGB
90    ///
91    /// # Arguments
92    ///
93    /// * `device` - GPU device
94    /// * `input` - Input YUV buffer (packed YUVA format)
95    /// * `output` - Output RGB buffer (packed RGBA format)
96    /// * `width` - Image width
97    /// * `height` - Image height
98    /// * `color_space` - Color space standard (BT.601, BT.709, BT.2020)
99    ///
100    /// # Errors
101    ///
102    /// Returns an error if buffer sizes are invalid or if the GPU operation fails.
103    #[allow(clippy::too_many_arguments)]
104    pub fn yuv_to_rgb(
105        device: &GpuDevice,
106        input: &[u8],
107        output: &mut [u8],
108        width: u32,
109        height: u32,
110        color_space: ColorSpace,
111    ) -> Result<()> {
112        utils::validate_dimensions(width, height)?;
113        utils::validate_buffer_size(input, width, height, 4)?;
114        utils::validate_buffer_size(output, width, height, 4)?;
115
116        let pipeline = Self::get_yuv_to_rgb_pipeline(device)?;
117        let layout = Self::get_bind_group_layout(device)?;
118
119        Self::execute_conversion(
120            device,
121            pipeline,
122            layout,
123            input,
124            output,
125            width,
126            height,
127            color_space,
128        )
129    }
130
131    #[allow(clippy::too_many_arguments)]
132    fn execute_conversion(
133        device: &GpuDevice,
134        pipeline: &ComputePipeline,
135        layout: &BindGroupLayout,
136        input: &[u8],
137        output: &mut [u8],
138        width: u32,
139        height: u32,
140        color_space: ColorSpace,
141    ) -> Result<()> {
142        // Create buffers
143        let input_buffer = utils::create_storage_buffer(device, input.len() as u64)?;
144        let output_buffer = utils::create_storage_buffer(device, output.len() as u64)?;
145
146        // Upload input data
147        device.queue().write_buffer(input_buffer.buffer(), 0, input);
148
149        // Create uniform buffer for parameters
150        let params = ConversionParams {
151            width,
152            height,
153            stride: width,
154            format: color_space.to_format_id(),
155        };
156        let params_bytes = bytemuck::bytes_of(&params);
157        let params_buffer = utils::create_uniform_buffer(device, params_bytes)?;
158
159        // Create bind group
160        let compiler = ShaderCompiler::new(device);
161        let bind_group = compiler.create_bind_group(
162            "ColorSpace Bind Group",
163            layout,
164            &[
165                wgpu::BindGroupEntry {
166                    binding: 0,
167                    resource: input_buffer.buffer().as_entire_binding(),
168                },
169                wgpu::BindGroupEntry {
170                    binding: 1,
171                    resource: output_buffer.buffer().as_entire_binding(),
172                },
173                wgpu::BindGroupEntry {
174                    binding: 2,
175                    resource: params_buffer.buffer().as_entire_binding(),
176                },
177            ],
178        );
179
180        // Execute compute pass
181        Self::dispatch_compute(device, pipeline, &bind_group, width, height)?;
182
183        // Read back results
184        let readback_buffer = utils::create_readback_buffer(device, output.len() as u64)?;
185        let mut encoder = device
186            .device()
187            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
188                label: Some("ColorSpace Copy Encoder"),
189            });
190
191        output_buffer.copy_to(&mut encoder, &readback_buffer, 0, 0, output.len() as u64)?;
192
193        device.queue().submit(Some(encoder.finish()));
194        device.wait();
195
196        let result = readback_buffer.read(device, 0, output.len() as u64)?;
197        output.copy_from_slice(&result);
198
199        Ok(())
200    }
201
202    fn dispatch_compute(
203        device: &GpuDevice,
204        pipeline: &ComputePipeline,
205        bind_group: &BindGroup,
206        width: u32,
207        height: u32,
208    ) -> Result<()> {
209        let mut encoder = device
210            .device()
211            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
212                label: Some("ColorSpace Compute Encoder"),
213            });
214
215        {
216            let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
217                label: Some("ColorSpace Compute Pass"),
218                timestamp_writes: None,
219            });
220
221            compute_pass.set_pipeline(pipeline);
222            compute_pass.set_bind_group(0, bind_group, &[]);
223
224            let (dispatch_x, dispatch_y) = utils::calculate_dispatch_size(width, height, (16, 16));
225            compute_pass.dispatch_workgroups(dispatch_x, dispatch_y, 1);
226        }
227
228        device.queue().submit(Some(encoder.finish()));
229        Ok(())
230    }
231
232    fn get_bind_group_layout(device: &GpuDevice) -> Result<&'static BindGroupLayout> {
233        static LAYOUT: OnceCell<BindGroupLayout> = OnceCell::new();
234
235        Ok(LAYOUT.get_or_init(|| {
236            let compiler = ShaderCompiler::new(device);
237            let entries = BindGroupLayoutBuilder::new()
238                .add_storage_buffer_read_only(0) // input
239                .add_storage_buffer(1) // output
240                .add_uniform_buffer(2) // params
241                .build();
242
243            compiler.create_bind_group_layout("ColorSpace Bind Group Layout", &entries)
244        }))
245    }
246
247    fn get_rgb_to_yuv_pipeline(device: &GpuDevice) -> Result<&'static ComputePipeline> {
248        static PIPELINE: OnceCell<ComputePipeline> = OnceCell::new();
249
250        Ok(PIPELINE.get_or_init(|| {
251            let compiler = ShaderCompiler::new(device);
252            let shader = compiler
253                .compile(
254                    "ColorSpace Shader",
255                    ShaderSource::Embedded(crate::shader::embedded::COLORSPACE_SHADER),
256                )
257                .expect("Failed to compile colorspace shader");
258
259            let layout =
260                Self::get_bind_group_layout(device).expect("Failed to create bind group layout");
261
262            compiler
263                .create_pipeline("RGB to YUV Pipeline", &shader, "rgb_to_yuv_main", layout)
264                .expect("Failed to create pipeline")
265        }))
266    }
267
268    fn get_yuv_to_rgb_pipeline(device: &GpuDevice) -> Result<&'static ComputePipeline> {
269        static PIPELINE: OnceCell<ComputePipeline> = OnceCell::new();
270
271        Ok(PIPELINE.get_or_init(|| {
272            let compiler = ShaderCompiler::new(device);
273            let shader = compiler
274                .compile(
275                    "ColorSpace Shader",
276                    ShaderSource::Embedded(crate::shader::embedded::COLORSPACE_SHADER),
277                )
278                .expect("Failed to compile colorspace shader");
279
280            let layout =
281                Self::get_bind_group_layout(device).expect("Failed to create bind group layout");
282
283            compiler
284                .create_pipeline("YUV to RGB Pipeline", &shader, "yuv_to_rgb_main", layout)
285                .expect("Failed to create pipeline")
286        }))
287    }
288}