Skip to main content

oximedia_gpu/ops/
transform.rs

1//! Transform operations (DCT, FFT) for frequency domain processing
2
3use crate::{
4    shader::{BindGroupLayoutBuilder, ShaderCompiler, ShaderSource},
5    GpuDevice, GpuError, Result,
6};
7use bytemuck::{Pod, Zeroable};
8use once_cell::sync::OnceCell;
9use wgpu::{BindGroup, BindGroupLayout, ComputePipeline};
10
11use super::utils;
12
13#[repr(C)]
14#[derive(Copy, Clone, Pod, Zeroable)]
15struct TransformParams {
16    width: u32,
17    height: u32,
18    block_size: u32,
19    transform_type: u32,
20    stride: u32,
21    is_inverse: u32,
22    padding1: u32,
23    padding2: u32,
24}
25
26/// Transform operations for frequency domain processing
27pub struct TransformOperation;
28
29impl TransformOperation {
30    /// Compute 2D DCT (Discrete Cosine Transform)
31    ///
32    /// Computes the forward DCT on 8x8 blocks. Input dimensions must be
33    /// multiples of 8.
34    ///
35    /// # Arguments
36    ///
37    /// * `device` - GPU device
38    /// * `input` - Input data (f32 values)
39    /// * `output` - Output DCT coefficients
40    /// * `width` - Data width (must be multiple of 8)
41    /// * `height` - Data height (must be multiple of 8)
42    ///
43    /// # Errors
44    ///
45    /// Returns an error if dimensions are invalid or if the GPU operation fails.
46    pub fn dct_2d(
47        device: &GpuDevice,
48        input: &[f32],
49        output: &mut [f32],
50        width: u32,
51        height: u32,
52    ) -> Result<()> {
53        if width % 8 != 0 || height % 8 != 0 {
54            return Err(GpuError::InvalidDimensions { width, height });
55        }
56
57        utils::validate_dimensions(width, height)?;
58
59        let expected_size = (width * height) as usize;
60        if input.len() < expected_size || output.len() < expected_size {
61            return Err(GpuError::InvalidBufferSize {
62                expected: expected_size,
63                actual: input.len().min(output.len()),
64            });
65        }
66
67        let pipeline = Self::get_dct_8x8_pipeline(device)?;
68        let layout = Self::get_bind_group_layout(device)?;
69
70        Self::execute_transform(
71            device, pipeline, layout, input, output, width, height, 8, 0, // DCT
72        )
73    }
74
75    /// Compute 2D IDCT (Inverse Discrete Cosine Transform)
76    ///
77    /// Computes the inverse DCT on 8x8 blocks. Input dimensions must be
78    /// multiples of 8.
79    ///
80    /// # Arguments
81    ///
82    /// * `device` - GPU device
83    /// * `input` - Input DCT coefficients
84    /// * `output` - Output reconstructed data
85    /// * `width` - Data width (must be multiple of 8)
86    /// * `height` - Data height (must be multiple of 8)
87    ///
88    /// # Errors
89    ///
90    /// Returns an error if dimensions are invalid or if the GPU operation fails.
91    pub fn idct_2d(
92        device: &GpuDevice,
93        input: &[f32],
94        output: &mut [f32],
95        width: u32,
96        height: u32,
97    ) -> Result<()> {
98        if width % 8 != 0 || height % 8 != 0 {
99            return Err(GpuError::InvalidDimensions { width, height });
100        }
101
102        utils::validate_dimensions(width, height)?;
103
104        let expected_size = (width * height) as usize;
105        if input.len() < expected_size || output.len() < expected_size {
106            return Err(GpuError::InvalidBufferSize {
107                expected: expected_size,
108                actual: input.len().min(output.len()),
109            });
110        }
111
112        let pipeline = Self::get_idct_8x8_pipeline(device)?;
113        let layout = Self::get_bind_group_layout(device)?;
114
115        Self::execute_transform(
116            device, pipeline, layout, input, output, width, height, 8, 1, // IDCT
117        )
118    }
119
120    /// Compute general 2D DCT using row-column decomposition
121    ///
122    /// This method works for any dimensions, not just multiples of 8.
123    ///
124    /// # Arguments
125    ///
126    /// * `device` - GPU device
127    /// * `input` - Input data (f32 values)
128    /// * `output` - Output DCT coefficients
129    /// * `width` - Data width
130    /// * `height` - Data height
131    ///
132    /// # Errors
133    ///
134    /// Returns an error if dimensions are invalid or if the GPU operation fails.
135    pub fn dct_2d_general(
136        device: &GpuDevice,
137        input: &[f32],
138        output: &mut [f32],
139        width: u32,
140        height: u32,
141    ) -> Result<()> {
142        utils::validate_dimensions(width, height)?;
143
144        let expected_size = (width * height) as usize;
145        if input.len() < expected_size || output.len() < expected_size {
146            return Err(GpuError::InvalidBufferSize {
147                expected: expected_size,
148                actual: input.len().min(output.len()),
149            });
150        }
151
152        // Two-pass DCT: row then column
153        let mut temp = vec![0.0f32; expected_size];
154
155        // Row DCT
156        let row_pipeline = Self::get_dct_row_pipeline(device)?;
157        let layout = Self::get_bind_group_layout(device)?;
158
159        Self::execute_transform(
160            device,
161            row_pipeline,
162            layout,
163            input,
164            &mut temp,
165            width,
166            height,
167            width,
168            0,
169        )?;
170
171        // Column DCT
172        let col_pipeline = Self::get_dct_col_pipeline(device)?;
173
174        Self::execute_transform(
175            device,
176            col_pipeline,
177            layout,
178            &temp,
179            output,
180            width,
181            height,
182            height,
183            0,
184        )
185    }
186
187    #[allow(clippy::too_many_arguments)]
188    fn execute_transform(
189        device: &GpuDevice,
190        pipeline: &ComputePipeline,
191        layout: &BindGroupLayout,
192        input: &[f32],
193        output: &mut [f32],
194        width: u32,
195        height: u32,
196        block_size: u32,
197        transform_type: u32,
198    ) -> Result<()> {
199        let input_bytes = bytemuck::cast_slice(input);
200        let output_size = std::mem::size_of_val(output);
201
202        // Create buffers
203        let input_buffer = utils::create_storage_buffer(device, input_bytes.len() as u64)?;
204        let output_buffer = utils::create_storage_buffer(device, output_size as u64)?;
205
206        // Upload input data
207        device
208            .queue()
209            .write_buffer(input_buffer.buffer(), 0, input_bytes);
210
211        // Create uniform buffer for parameters
212        let params = TransformParams {
213            width,
214            height,
215            block_size,
216            transform_type,
217            stride: width,
218            is_inverse: 0,
219            padding1: 0,
220            padding2: 0,
221        };
222        let params_bytes = bytemuck::bytes_of(&params);
223        let params_buffer = utils::create_uniform_buffer(device, params_bytes)?;
224
225        // Create bind group
226        let compiler = ShaderCompiler::new(device);
227        let bind_group = compiler.create_bind_group(
228            "Transform Bind Group",
229            layout,
230            &[
231                wgpu::BindGroupEntry {
232                    binding: 0,
233                    resource: input_buffer.buffer().as_entire_binding(),
234                },
235                wgpu::BindGroupEntry {
236                    binding: 1,
237                    resource: output_buffer.buffer().as_entire_binding(),
238                },
239                wgpu::BindGroupEntry {
240                    binding: 2,
241                    resource: params_buffer.buffer().as_entire_binding(),
242                },
243            ],
244        );
245
246        // Execute compute pass
247        Self::dispatch_compute(device, pipeline, &bind_group, width, height, block_size)?;
248
249        // Read back results
250        let readback_buffer = utils::create_readback_buffer(device, output_size as u64)?;
251        let mut encoder = device
252            .device()
253            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
254                label: Some("Transform Copy Encoder"),
255            });
256
257        output_buffer.copy_to(&mut encoder, &readback_buffer, 0, 0, output_size as u64)?;
258
259        device.queue().submit(Some(encoder.finish()));
260        device.wait();
261
262        let result = readback_buffer.read(device, 0, output_size as u64)?;
263        let result_f32: &[f32] = bytemuck::cast_slice(&result);
264        output.copy_from_slice(result_f32);
265
266        Ok(())
267    }
268
269    fn dispatch_compute(
270        device: &GpuDevice,
271        pipeline: &ComputePipeline,
272        bind_group: &BindGroup,
273        width: u32,
274        height: u32,
275        block_size: u32,
276    ) -> Result<()> {
277        let mut encoder = device
278            .device()
279            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
280                label: Some("Transform Compute Encoder"),
281            });
282
283        {
284            let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
285                label: Some("Transform Compute Pass"),
286                timestamp_writes: None,
287            });
288
289            compute_pass.set_pipeline(pipeline);
290            compute_pass.set_bind_group(0, bind_group, &[]);
291
292            if block_size == 8 {
293                // For 8x8 DCT, dispatch one workgroup per block
294                let dispatch_x = width / 8;
295                let dispatch_y = height / 8;
296                compute_pass.dispatch_workgroups(dispatch_x, dispatch_y, 1);
297            } else {
298                // For row/column transforms
299                let total_elements = width * height;
300                let dispatch = total_elements.div_ceil(256);
301                compute_pass.dispatch_workgroups(dispatch, 1, 1);
302            }
303        }
304
305        device.queue().submit(Some(encoder.finish()));
306        Ok(())
307    }
308
309    fn get_bind_group_layout(device: &GpuDevice) -> Result<&'static BindGroupLayout> {
310        static LAYOUT: OnceCell<BindGroupLayout> = OnceCell::new();
311
312        Ok(LAYOUT.get_or_init(|| {
313            let compiler = ShaderCompiler::new(device);
314            let entries = BindGroupLayoutBuilder::new()
315                .add_storage_buffer_read_only(0) // input
316                .add_storage_buffer(1) // output
317                .add_uniform_buffer(2) // params
318                .build();
319
320            compiler.create_bind_group_layout("Transform Bind Group Layout", &entries)
321        }))
322    }
323
324    fn get_dct_8x8_pipeline(device: &GpuDevice) -> Result<&'static ComputePipeline> {
325        static PIPELINE: OnceCell<ComputePipeline> = OnceCell::new();
326
327        Ok(PIPELINE.get_or_init(|| {
328            let compiler = ShaderCompiler::new(device);
329            let shader = compiler
330                .compile(
331                    "Transform Shader",
332                    ShaderSource::Embedded(crate::shader::embedded::TRANSFORM_SHADER),
333                )
334                .expect("Failed to compile transform shader");
335
336            let layout =
337                Self::get_bind_group_layout(device).expect("Failed to create bind group layout");
338
339            compiler
340                .create_pipeline("DCT 8x8 Pipeline", &shader, "dct_8x8", layout)
341                .expect("Failed to create pipeline")
342        }))
343    }
344
345    fn get_idct_8x8_pipeline(device: &GpuDevice) -> Result<&'static ComputePipeline> {
346        static PIPELINE: OnceCell<ComputePipeline> = OnceCell::new();
347
348        Ok(PIPELINE.get_or_init(|| {
349            let compiler = ShaderCompiler::new(device);
350            let shader = compiler
351                .compile(
352                    "Transform Shader",
353                    ShaderSource::Embedded(crate::shader::embedded::TRANSFORM_SHADER),
354                )
355                .expect("Failed to compile transform shader");
356
357            let layout =
358                Self::get_bind_group_layout(device).expect("Failed to create bind group layout");
359
360            compiler
361                .create_pipeline("IDCT 8x8 Pipeline", &shader, "idct_8x8", layout)
362                .expect("Failed to create pipeline")
363        }))
364    }
365
366    fn get_dct_row_pipeline(device: &GpuDevice) -> Result<&'static ComputePipeline> {
367        static PIPELINE: OnceCell<ComputePipeline> = OnceCell::new();
368
369        Ok(PIPELINE.get_or_init(|| {
370            let compiler = ShaderCompiler::new(device);
371            let shader = compiler
372                .compile(
373                    "Transform Shader",
374                    ShaderSource::Embedded(crate::shader::embedded::TRANSFORM_SHADER),
375                )
376                .expect("Failed to compile transform shader");
377
378            let layout =
379                Self::get_bind_group_layout(device).expect("Failed to create bind group layout");
380
381            compiler
382                .create_pipeline("DCT Row Pipeline", &shader, "dct_row", layout)
383                .expect("Failed to create pipeline")
384        }))
385    }
386
387    fn get_dct_col_pipeline(device: &GpuDevice) -> Result<&'static ComputePipeline> {
388        static PIPELINE: OnceCell<ComputePipeline> = OnceCell::new();
389
390        Ok(PIPELINE.get_or_init(|| {
391            let compiler = ShaderCompiler::new(device);
392            let shader = compiler
393                .compile(
394                    "Transform Shader",
395                    ShaderSource::Embedded(crate::shader::embedded::TRANSFORM_SHADER),
396                )
397                .expect("Failed to compile transform shader");
398
399            let layout =
400                Self::get_bind_group_layout(device).expect("Failed to create bind group layout");
401
402            compiler
403                .create_pipeline("DCT Column Pipeline", &shader, "dct_col", layout)
404                .expect("Failed to create pipeline")
405        }))
406    }
407}