1use crate::{
4 shader::{BindGroupLayoutBuilder, ShaderCompiler, ShaderSource},
5 GpuDevice, GpuError, Result,
6};
7use bytemuck::{Pod, Zeroable};
8use once_cell::sync::OnceCell;
9use wgpu::{BindGroup, BindGroupLayout, ComputePipeline};
10
11use super::utils;
12
/// Parameter block uploaded as the uniform buffer (binding 2) for every transform dispatch.
///
/// The struct is `#[repr(C)]` and consists of eight `u32` fields (32 bytes); it is turned
/// into raw bytes with `bytemuck::bytes_of` before upload.
// NOTE(review): the field order presumably mirrors a matching struct in the transform
// shader — confirm against the shader source before reordering or adding fields.
#[repr(C)]
#[derive(Copy, Clone, Pod, Zeroable)]
struct TransformParams {
    /// Width value forwarded unchanged from the public entry point.
    width: u32,
    /// Height value forwarded unchanged from the public entry point.
    height: u32,
    /// 8 for the 8x8 entry points; `width` / `height` for the row / column passes of
    /// `dct_2d_general`.
    block_size: u32,
    /// 0 when called from `dct_2d` and `dct_2d_general`, 1 when called from `idct_2d`.
    transform_type: u32,
    /// Always set to `width` by `execute_transform`.
    stride: u32,
    /// Always written as 0 by `execute_transform`.
    is_inverse: u32,
    /// Always 0. Presumably padding to keep the uniform block size aligned — TODO confirm.
    padding1: u32,
    /// Always 0. Presumably padding to keep the uniform block size aligned — TODO confirm.
    padding2: u32,
}
25
/// Namespace type for the DCT transform entry points (`dct_2d`, `idct_2d`,
/// `dct_2d_general`). It carries no state; every method is an associated function that
/// takes the `GpuDevice` to run on as its first argument.
pub struct TransformOperation;
28
29impl TransformOperation {
30 pub fn dct_2d(
47 device: &GpuDevice,
48 input: &[f32],
49 output: &mut [f32],
50 width: u32,
51 height: u32,
52 ) -> Result<()> {
53 if width % 8 != 0 || height % 8 != 0 {
54 return Err(GpuError::InvalidDimensions { width, height });
55 }
56
57 utils::validate_dimensions(width, height)?;
58
59 let expected_size = (width * height) as usize;
60 if input.len() < expected_size || output.len() < expected_size {
61 return Err(GpuError::InvalidBufferSize {
62 expected: expected_size,
63 actual: input.len().min(output.len()),
64 });
65 }
66
67 let pipeline = Self::get_dct_8x8_pipeline(device)?;
68 let layout = Self::get_bind_group_layout(device)?;
69
70 Self::execute_transform(
71 device, pipeline, layout, input, output, width, height, 8, 0, )
73 }
74
75 pub fn idct_2d(
92 device: &GpuDevice,
93 input: &[f32],
94 output: &mut [f32],
95 width: u32,
96 height: u32,
97 ) -> Result<()> {
98 if width % 8 != 0 || height % 8 != 0 {
99 return Err(GpuError::InvalidDimensions { width, height });
100 }
101
102 utils::validate_dimensions(width, height)?;
103
104 let expected_size = (width * height) as usize;
105 if input.len() < expected_size || output.len() < expected_size {
106 return Err(GpuError::InvalidBufferSize {
107 expected: expected_size,
108 actual: input.len().min(output.len()),
109 });
110 }
111
112 let pipeline = Self::get_idct_8x8_pipeline(device)?;
113 let layout = Self::get_bind_group_layout(device)?;
114
115 Self::execute_transform(
116 device, pipeline, layout, input, output, width, height, 8, 1, )
118 }
119
120 pub fn dct_2d_general(
136 device: &GpuDevice,
137 input: &[f32],
138 output: &mut [f32],
139 width: u32,
140 height: u32,
141 ) -> Result<()> {
142 utils::validate_dimensions(width, height)?;
143
144 let expected_size = (width * height) as usize;
145 if input.len() < expected_size || output.len() < expected_size {
146 return Err(GpuError::InvalidBufferSize {
147 expected: expected_size,
148 actual: input.len().min(output.len()),
149 });
150 }
151
152 let mut temp = vec![0.0f32; expected_size];
154
155 let row_pipeline = Self::get_dct_row_pipeline(device)?;
157 let layout = Self::get_bind_group_layout(device)?;
158
159 Self::execute_transform(
160 device,
161 row_pipeline,
162 layout,
163 input,
164 &mut temp,
165 width,
166 height,
167 width,
168 0,
169 )?;
170
171 let col_pipeline = Self::get_dct_col_pipeline(device)?;
173
174 Self::execute_transform(
175 device,
176 col_pipeline,
177 layout,
178 &temp,
179 output,
180 width,
181 height,
182 height,
183 0,
184 )
185 }
186
187 #[allow(clippy::too_many_arguments)]
188 fn execute_transform(
189 device: &GpuDevice,
190 pipeline: &ComputePipeline,
191 layout: &BindGroupLayout,
192 input: &[f32],
193 output: &mut [f32],
194 width: u32,
195 height: u32,
196 block_size: u32,
197 transform_type: u32,
198 ) -> Result<()> {
199 let input_bytes = bytemuck::cast_slice(input);
200 let output_size = std::mem::size_of_val(output);
201
202 let input_buffer = utils::create_storage_buffer(device, input_bytes.len() as u64)?;
204 let output_buffer = utils::create_storage_buffer(device, output_size as u64)?;
205
206 device
208 .queue()
209 .write_buffer(input_buffer.buffer(), 0, input_bytes);
210
211 let params = TransformParams {
213 width,
214 height,
215 block_size,
216 transform_type,
217 stride: width,
218 is_inverse: 0,
219 padding1: 0,
220 padding2: 0,
221 };
222 let params_bytes = bytemuck::bytes_of(¶ms);
223 let params_buffer = utils::create_uniform_buffer(device, params_bytes)?;
224
225 let compiler = ShaderCompiler::new(device);
227 let bind_group = compiler.create_bind_group(
228 "Transform Bind Group",
229 layout,
230 &[
231 wgpu::BindGroupEntry {
232 binding: 0,
233 resource: input_buffer.buffer().as_entire_binding(),
234 },
235 wgpu::BindGroupEntry {
236 binding: 1,
237 resource: output_buffer.buffer().as_entire_binding(),
238 },
239 wgpu::BindGroupEntry {
240 binding: 2,
241 resource: params_buffer.buffer().as_entire_binding(),
242 },
243 ],
244 );
245
246 Self::dispatch_compute(device, pipeline, &bind_group, width, height, block_size)?;
248
249 let readback_buffer = utils::create_readback_buffer(device, output_size as u64)?;
251 let mut encoder = device
252 .device()
253 .create_command_encoder(&wgpu::CommandEncoderDescriptor {
254 label: Some("Transform Copy Encoder"),
255 });
256
257 output_buffer.copy_to(&mut encoder, &readback_buffer, 0, 0, output_size as u64)?;
258
259 device.queue().submit(Some(encoder.finish()));
260 device.wait();
261
262 let result = readback_buffer.read(device, 0, output_size as u64)?;
263 let result_f32: &[f32] = bytemuck::cast_slice(&result);
264 output.copy_from_slice(result_f32);
265
266 Ok(())
267 }
268
269 fn dispatch_compute(
270 device: &GpuDevice,
271 pipeline: &ComputePipeline,
272 bind_group: &BindGroup,
273 width: u32,
274 height: u32,
275 block_size: u32,
276 ) -> Result<()> {
277 let mut encoder = device
278 .device()
279 .create_command_encoder(&wgpu::CommandEncoderDescriptor {
280 label: Some("Transform Compute Encoder"),
281 });
282
283 {
284 let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
285 label: Some("Transform Compute Pass"),
286 timestamp_writes: None,
287 });
288
289 compute_pass.set_pipeline(pipeline);
290 compute_pass.set_bind_group(0, bind_group, &[]);
291
292 if block_size == 8 {
293 let dispatch_x = width / 8;
295 let dispatch_y = height / 8;
296 compute_pass.dispatch_workgroups(dispatch_x, dispatch_y, 1);
297 } else {
298 let total_elements = width * height;
300 let dispatch = total_elements.div_ceil(256);
301 compute_pass.dispatch_workgroups(dispatch, 1, 1);
302 }
303 }
304
305 device.queue().submit(Some(encoder.finish()));
306 Ok(())
307 }
308
309 fn get_bind_group_layout(device: &GpuDevice) -> Result<&'static BindGroupLayout> {
310 static LAYOUT: OnceCell<BindGroupLayout> = OnceCell::new();
311
312 Ok(LAYOUT.get_or_init(|| {
313 let compiler = ShaderCompiler::new(device);
314 let entries = BindGroupLayoutBuilder::new()
315 .add_storage_buffer_read_only(0) .add_storage_buffer(1) .add_uniform_buffer(2) .build();
319
320 compiler.create_bind_group_layout("Transform Bind Group Layout", &entries)
321 }))
322 }
323
324 fn get_dct_8x8_pipeline(device: &GpuDevice) -> Result<&'static ComputePipeline> {
325 static PIPELINE: OnceCell<ComputePipeline> = OnceCell::new();
326
327 Ok(PIPELINE.get_or_init(|| {
328 let compiler = ShaderCompiler::new(device);
329 let shader = compiler
330 .compile(
331 "Transform Shader",
332 ShaderSource::Embedded(crate::shader::embedded::TRANSFORM_SHADER),
333 )
334 .expect("Failed to compile transform shader");
335
336 let layout =
337 Self::get_bind_group_layout(device).expect("Failed to create bind group layout");
338
339 compiler
340 .create_pipeline("DCT 8x8 Pipeline", &shader, "dct_8x8", layout)
341 .expect("Failed to create pipeline")
342 }))
343 }
344
345 fn get_idct_8x8_pipeline(device: &GpuDevice) -> Result<&'static ComputePipeline> {
346 static PIPELINE: OnceCell<ComputePipeline> = OnceCell::new();
347
348 Ok(PIPELINE.get_or_init(|| {
349 let compiler = ShaderCompiler::new(device);
350 let shader = compiler
351 .compile(
352 "Transform Shader",
353 ShaderSource::Embedded(crate::shader::embedded::TRANSFORM_SHADER),
354 )
355 .expect("Failed to compile transform shader");
356
357 let layout =
358 Self::get_bind_group_layout(device).expect("Failed to create bind group layout");
359
360 compiler
361 .create_pipeline("IDCT 8x8 Pipeline", &shader, "idct_8x8", layout)
362 .expect("Failed to create pipeline")
363 }))
364 }
365
366 fn get_dct_row_pipeline(device: &GpuDevice) -> Result<&'static ComputePipeline> {
367 static PIPELINE: OnceCell<ComputePipeline> = OnceCell::new();
368
369 Ok(PIPELINE.get_or_init(|| {
370 let compiler = ShaderCompiler::new(device);
371 let shader = compiler
372 .compile(
373 "Transform Shader",
374 ShaderSource::Embedded(crate::shader::embedded::TRANSFORM_SHADER),
375 )
376 .expect("Failed to compile transform shader");
377
378 let layout =
379 Self::get_bind_group_layout(device).expect("Failed to create bind group layout");
380
381 compiler
382 .create_pipeline("DCT Row Pipeline", &shader, "dct_row", layout)
383 .expect("Failed to create pipeline")
384 }))
385 }
386
387 fn get_dct_col_pipeline(device: &GpuDevice) -> Result<&'static ComputePipeline> {
388 static PIPELINE: OnceCell<ComputePipeline> = OnceCell::new();
389
390 Ok(PIPELINE.get_or_init(|| {
391 let compiler = ShaderCompiler::new(device);
392 let shader = compiler
393 .compile(
394 "Transform Shader",
395 ShaderSource::Embedded(crate::shader::embedded::TRANSFORM_SHADER),
396 )
397 .expect("Failed to compile transform shader");
398
399 let layout =
400 Self::get_bind_group_layout(device).expect("Failed to create bind group layout");
401
402 compiler
403 .create_pipeline("DCT Column Pipeline", &shader, "dct_col", layout)
404 .expect("Failed to create pipeline")
405 }))
406 }
407}