1use crate::{
4 shader::{BindGroupLayoutBuilder, ShaderCompiler, ShaderSource},
5 GpuDevice, GpuError, Result,
6};
7use bytemuck::{Pod, Zeroable};
8use once_cell::sync::OnceCell;
9use wgpu::{BindGroup, BindGroupLayout, ComputePipeline};
10
11use super::utils;
12
13#[repr(C)]
14#[derive(Copy, Clone, Pod, Zeroable)]
15struct FilterParams {
16 width: u32,
17 height: u32,
18 stride: u32,
19 kernel_size: u32,
20 normalize: u32,
21 filter_type: u32,
22 padding: u32,
23 sigma: f32,
24}
25
/// Stateless namespace for the GPU image filters (Gaussian blur, sharpen,
/// edge detection and custom convolution).
///
/// All entry points are associated functions that take a `GpuDevice`.
#[derive(Debug, Clone, Copy, Default)]
pub struct FilterOperation;

29impl FilterOperation {
30 #[allow(clippy::too_many_arguments)]
45 pub fn gaussian_blur(
46 device: &GpuDevice,
47 input: &[u8],
48 output: &mut [u8],
49 width: u32,
50 height: u32,
51 sigma: f32,
52 ) -> Result<()> {
53 utils::validate_dimensions(width, height)?;
54 utils::validate_buffer_size(input, width, height, 4)?;
55 utils::validate_buffer_size(output, width, height, 4)?;
56
57 let kernel_size = Self::calculate_kernel_size(sigma);
58 let pipeline = Self::get_gaussian_pipeline(device)?;
59 let layout = Self::get_bind_group_layout(device)?;
60
61 Self::execute_filter(
62 device,
63 pipeline,
64 layout,
65 input,
66 output,
67 width,
68 height,
69 kernel_size,
70 1, sigma,
72 )
73 }
74
75 #[allow(clippy::too_many_arguments)]
90 pub fn sharpen(
91 device: &GpuDevice,
92 input: &[u8],
93 output: &mut [u8],
94 width: u32,
95 height: u32,
96 amount: f32,
97 ) -> Result<()> {
98 utils::validate_dimensions(width, height)?;
99 utils::validate_buffer_size(input, width, height, 4)?;
100 utils::validate_buffer_size(output, width, height, 4)?;
101
102 let pipeline = Self::get_sharpen_pipeline(device)?;
103 let layout = Self::get_bind_group_layout(device)?;
104
105 Self::execute_filter(
106 device, pipeline, layout, input, output, width, height,
107 5, 2, amount,
110 )
111 }
112
113 pub fn edge_detect(
127 device: &GpuDevice,
128 input: &[u8],
129 output: &mut [u8],
130 width: u32,
131 height: u32,
132 ) -> Result<()> {
133 utils::validate_dimensions(width, height)?;
134 utils::validate_buffer_size(input, width, height, 4)?;
135 utils::validate_buffer_size(output, width, height, 4)?;
136
137 let pipeline = Self::get_edge_detect_pipeline(device)?;
138 let layout = Self::get_bind_group_layout(device)?;
139
140 Self::execute_filter(
141 device, pipeline, layout, input, output, width, height, 3, 3, 0.0,
144 )
145 }
146
147 #[allow(clippy::too_many_arguments)]
163 pub fn convolve(
164 device: &GpuDevice,
165 input: &[u8],
166 output: &mut [u8],
167 width: u32,
168 height: u32,
169 kernel: &[f32],
170 normalize: bool,
171 ) -> Result<()> {
172 utils::validate_dimensions(width, height)?;
173 utils::validate_buffer_size(input, width, height, 4)?;
174 utils::validate_buffer_size(output, width, height, 4)?;
175
176 let kernel_size = (kernel.len() as f32).sqrt() as u32;
177 if kernel_size * kernel_size != kernel.len() as u32 {
178 return Err(GpuError::Internal("Kernel must be square".to_string()));
179 }
180 if kernel_size % 2 == 0 {
181 return Err(GpuError::Internal("Kernel size must be odd".to_string()));
182 }
183
184 let pipeline = Self::get_convolve_pipeline(device)?;
185 let layout = Self::get_bind_group_layout_with_kernel(device)?;
186
187 Self::execute_convolve(
188 device,
189 pipeline,
190 layout,
191 input,
192 output,
193 width,
194 height,
195 kernel,
196 kernel_size,
197 normalize,
198 )
199 }
200
201 #[allow(clippy::too_many_arguments)]
202 fn execute_filter(
203 device: &GpuDevice,
204 pipeline: &ComputePipeline,
205 layout: &BindGroupLayout,
206 input: &[u8],
207 output: &mut [u8],
208 width: u32,
209 height: u32,
210 kernel_size: u32,
211 filter_type: u32,
212 sigma: f32,
213 ) -> Result<()> {
214 let input_buffer = utils::create_storage_buffer(device, input.len() as u64)?;
216 let output_buffer = utils::create_storage_buffer(device, output.len() as u64)?;
217
218 device.queue().write_buffer(input_buffer.buffer(), 0, input);
220
221 let params = FilterParams {
223 width,
224 height,
225 stride: width,
226 kernel_size,
227 normalize: 1,
228 filter_type,
229 padding: 0,
230 sigma,
231 };
232 let params_bytes = bytemuck::bytes_of(¶ms);
233 let params_buffer = utils::create_uniform_buffer(device, params_bytes)?;
234
235 let compiler = ShaderCompiler::new(device);
237 let bind_group = compiler.create_bind_group(
238 "Filter Bind Group",
239 layout,
240 &[
241 wgpu::BindGroupEntry {
242 binding: 0,
243 resource: input_buffer.buffer().as_entire_binding(),
244 },
245 wgpu::BindGroupEntry {
246 binding: 1,
247 resource: output_buffer.buffer().as_entire_binding(),
248 },
249 wgpu::BindGroupEntry {
250 binding: 2,
251 resource: params_buffer.buffer().as_entire_binding(),
252 },
253 ],
254 );
255
256 Self::dispatch_compute(device, pipeline, &bind_group, width, height)?;
258
259 let readback_buffer = utils::create_readback_buffer(device, output.len() as u64)?;
261 let mut encoder = device
262 .device()
263 .create_command_encoder(&wgpu::CommandEncoderDescriptor {
264 label: Some("Filter Copy Encoder"),
265 });
266
267 output_buffer.copy_to(&mut encoder, &readback_buffer, 0, 0, output.len() as u64)?;
268
269 device.queue().submit(Some(encoder.finish()));
270 device.wait();
271
272 let result = readback_buffer.read(device, 0, output.len() as u64)?;
273 output.copy_from_slice(&result);
274
275 Ok(())
276 }
277
278 #[allow(clippy::too_many_arguments)]
279 fn execute_convolve(
280 device: &GpuDevice,
281 pipeline: &ComputePipeline,
282 layout: &BindGroupLayout,
283 input: &[u8],
284 output: &mut [u8],
285 width: u32,
286 height: u32,
287 kernel: &[f32],
288 kernel_size: u32,
289 normalize: bool,
290 ) -> Result<()> {
291 let input_buffer = utils::create_storage_buffer(device, input.len() as u64)?;
293 let output_buffer = utils::create_storage_buffer(device, output.len() as u64)?;
294
295 device.queue().write_buffer(input_buffer.buffer(), 0, input);
297
298 let kernel_bytes = bytemuck::cast_slice(kernel);
300 let kernel_buffer = utils::create_storage_buffer(device, kernel_bytes.len() as u64)?;
301 device
302 .queue()
303 .write_buffer(kernel_buffer.buffer(), 0, kernel_bytes);
304
305 let params = FilterParams {
307 width,
308 height,
309 stride: width,
310 kernel_size,
311 normalize: u32::from(normalize),
312 filter_type: 0, padding: 0,
314 sigma: 0.0,
315 };
316 let params_bytes = bytemuck::bytes_of(¶ms);
317 let params_buffer = utils::create_uniform_buffer(device, params_bytes)?;
318
319 let compiler = ShaderCompiler::new(device);
321 let bind_group = compiler.create_bind_group(
322 "Filter Bind Group",
323 layout,
324 &[
325 wgpu::BindGroupEntry {
326 binding: 0,
327 resource: input_buffer.buffer().as_entire_binding(),
328 },
329 wgpu::BindGroupEntry {
330 binding: 1,
331 resource: output_buffer.buffer().as_entire_binding(),
332 },
333 wgpu::BindGroupEntry {
334 binding: 2,
335 resource: params_buffer.buffer().as_entire_binding(),
336 },
337 wgpu::BindGroupEntry {
338 binding: 3,
339 resource: kernel_buffer.buffer().as_entire_binding(),
340 },
341 ],
342 );
343
344 Self::dispatch_compute(device, pipeline, &bind_group, width, height)?;
346
347 let readback_buffer = utils::create_readback_buffer(device, output.len() as u64)?;
349 let mut encoder = device
350 .device()
351 .create_command_encoder(&wgpu::CommandEncoderDescriptor {
352 label: Some("Filter Copy Encoder"),
353 });
354
355 output_buffer.copy_to(&mut encoder, &readback_buffer, 0, 0, output.len() as u64)?;
356
357 device.queue().submit(Some(encoder.finish()));
358 device.wait();
359
360 let result = readback_buffer.read(device, 0, output.len() as u64)?;
361 output.copy_from_slice(&result);
362
363 Ok(())
364 }
365
366 fn dispatch_compute(
367 device: &GpuDevice,
368 pipeline: &ComputePipeline,
369 bind_group: &BindGroup,
370 width: u32,
371 height: u32,
372 ) -> Result<()> {
373 let mut encoder = device
374 .device()
375 .create_command_encoder(&wgpu::CommandEncoderDescriptor {
376 label: Some("Filter Compute Encoder"),
377 });
378
379 {
380 let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
381 label: Some("Filter Compute Pass"),
382 timestamp_writes: None,
383 });
384
385 compute_pass.set_pipeline(pipeline);
386 compute_pass.set_bind_group(0, bind_group, &[]);
387
388 let (dispatch_x, dispatch_y) = utils::calculate_dispatch_size(width, height, (16, 16));
389 compute_pass.dispatch_workgroups(dispatch_x, dispatch_y, 1);
390 }
391
392 device.queue().submit(Some(encoder.finish()));
393 Ok(())
394 }
395
396 fn calculate_kernel_size(sigma: f32) -> u32 {
397 let radius = (3.0 * sigma).ceil() as u32;
399 2 * radius + 1
400 }
401
402 fn get_bind_group_layout(device: &GpuDevice) -> Result<&'static BindGroupLayout> {
403 static LAYOUT: OnceCell<BindGroupLayout> = OnceCell::new();
404
405 Ok(LAYOUT.get_or_init(|| {
406 let compiler = ShaderCompiler::new(device);
407 let entries = BindGroupLayoutBuilder::new()
408 .add_storage_buffer_read_only(0) .add_storage_buffer(1) .add_uniform_buffer(2) .build();
412
413 compiler.create_bind_group_layout("Filter Bind Group Layout", &entries)
414 }))
415 }
416
417 fn get_bind_group_layout_with_kernel(device: &GpuDevice) -> Result<&'static BindGroupLayout> {
418 static LAYOUT: OnceCell<BindGroupLayout> = OnceCell::new();
419
420 Ok(LAYOUT.get_or_init(|| {
421 let compiler = ShaderCompiler::new(device);
422 let entries = BindGroupLayoutBuilder::new()
423 .add_storage_buffer_read_only(0) .add_storage_buffer(1) .add_uniform_buffer(2) .add_storage_buffer_read_only(3) .build();
428
429 compiler.create_bind_group_layout("Filter Bind Group Layout (with kernel)", &entries)
430 }))
431 }
432
433 fn get_gaussian_pipeline(device: &GpuDevice) -> Result<&'static ComputePipeline> {
434 static PIPELINE: OnceCell<ComputePipeline> = OnceCell::new();
435
436 Ok(PIPELINE.get_or_init(|| {
437 let compiler = ShaderCompiler::new(device);
438 let shader = compiler
439 .compile(
440 "Filter Shader",
441 ShaderSource::Embedded(crate::shader::embedded::FILTER_SHADER),
442 )
443 .expect("Failed to compile filter shader");
444
445 let layout =
446 Self::get_bind_group_layout(device).expect("Failed to create bind group layout");
447
448 compiler
449 .create_pipeline("Gaussian Blur Pipeline", &shader, "convolve_main", layout)
450 .expect("Failed to create pipeline")
451 }))
452 }
453
454 fn get_sharpen_pipeline(device: &GpuDevice) -> Result<&'static ComputePipeline> {
455 static PIPELINE: OnceCell<ComputePipeline> = OnceCell::new();
456
457 Ok(PIPELINE.get_or_init(|| {
458 let compiler = ShaderCompiler::new(device);
459 let shader = compiler
460 .compile(
461 "Filter Shader",
462 ShaderSource::Embedded(crate::shader::embedded::FILTER_SHADER),
463 )
464 .expect("Failed to compile filter shader");
465
466 let layout =
467 Self::get_bind_group_layout(device).expect("Failed to create bind group layout");
468
469 compiler
470 .create_pipeline("Sharpen Pipeline", &shader, "unsharp_mask", layout)
471 .expect("Failed to create pipeline")
472 }))
473 }
474
475 fn get_edge_detect_pipeline(device: &GpuDevice) -> Result<&'static ComputePipeline> {
476 static PIPELINE: OnceCell<ComputePipeline> = OnceCell::new();
477
478 Ok(PIPELINE.get_or_init(|| {
479 let compiler = ShaderCompiler::new(device);
480 let shader = compiler
481 .compile(
482 "Filter Shader",
483 ShaderSource::Embedded(crate::shader::embedded::FILTER_SHADER),
484 )
485 .expect("Failed to compile filter shader");
486
487 let layout =
488 Self::get_bind_group_layout(device).expect("Failed to create bind group layout");
489
490 compiler
491 .create_pipeline("Edge Detect Pipeline", &shader, "edge_detect", layout)
492 .expect("Failed to create pipeline")
493 }))
494 }
495
496 fn get_convolve_pipeline(device: &GpuDevice) -> Result<&'static ComputePipeline> {
497 static PIPELINE: OnceCell<ComputePipeline> = OnceCell::new();
498
499 Ok(PIPELINE.get_or_init(|| {
500 let compiler = ShaderCompiler::new(device);
501 let shader = compiler
502 .compile(
503 "Filter Shader",
504 ShaderSource::Embedded(crate::shader::embedded::FILTER_SHADER),
505 )
506 .expect("Failed to compile filter shader");
507
508 let layout = Self::get_bind_group_layout_with_kernel(device)
509 .expect("Failed to create bind group layout");
510
511 compiler
512 .create_pipeline("Convolve Pipeline", &shader, "convolve_main", layout)
513 .expect("Failed to create pipeline")
514 }))
515 }
516}