1use crate::buffer::GpuBuffer;
7use crate::context::GpuContext;
8use crate::error::{GpuError, GpuResult};
9use crate::shaders::{
10 ComputePipelineBuilder, WgslShader, create_compute_bind_group_layout, storage_buffer_layout,
11 uniform_buffer_layout,
12};
13use bytemuck::{Pod, Zeroable};
14use tracing::debug;
15use wgpu::{
16 BindGroupDescriptor, BindGroupEntry, BufferUsages, CommandEncoderDescriptor,
17 ComputePassDescriptor, ComputePipeline,
18};
19
20#[derive(Debug, Clone, Copy, Pod, Zeroable)]
22#[repr(C)]
23pub struct ConvolutionParams {
24 pub width: u32,
26 pub height: u32,
28 pub kernel_width: u32,
30 pub kernel_height: u32,
32}
33
34impl ConvolutionParams {
35 pub fn new(width: u32, height: u32, kernel_width: u32, kernel_height: u32) -> GpuResult<Self> {
37 if kernel_width % 2 == 0 || kernel_height % 2 == 0 {
38 return Err(GpuError::invalid_kernel_params(
39 "Kernel dimensions must be odd",
40 ));
41 }
42
43 Ok(Self {
44 width,
45 height,
46 kernel_width,
47 kernel_height,
48 })
49 }
50
51 pub fn square(width: u32, height: u32, kernel_size: u32) -> GpuResult<Self> {
53 Self::new(width, height, kernel_size, kernel_size)
54 }
55
56 pub fn kernel_center(&self) -> (u32, u32) {
58 (self.kernel_width / 2, self.kernel_height / 2)
59 }
60}
61
/// GPU compute kernel that applies a 2D convolution over an image buffer.
pub struct ConvolutionKernel {
    // Cloned handle to the device/queue pair used for all GPU work.
    context: GpuContext,
    // Compiled compute pipeline for the embedded WGSL `convolve` entry point.
    pipeline: ComputePipeline,
    // Layout: input storage, kernel storage, params uniform, output storage.
    bind_group_layout: wgpu::BindGroupLayout,
    // Must match the `@workgroup_size(16, 16)` attribute in the shader.
    workgroup_size: (u32, u32),
}
69
70impl ConvolutionKernel {
71 pub fn new(context: &GpuContext) -> GpuResult<Self> {
77 debug!("Creating convolution kernel");
78
79 let shader_source = Self::convolution_shader();
80 let mut shader = WgslShader::new(shader_source, "convolve");
81 let shader_module = shader.compile(context.device())?;
82
83 let bind_group_layout = create_compute_bind_group_layout(
84 context.device(),
85 &[
86 storage_buffer_layout(0, true), storage_buffer_layout(1, true), uniform_buffer_layout(2), storage_buffer_layout(3, false), ],
91 Some("ConvolutionKernel BindGroupLayout"),
92 )?;
93
94 let pipeline = ComputePipelineBuilder::new(context.device(), shader_module, "convolve")
95 .bind_group_layout(&bind_group_layout)
96 .label("ConvolutionKernel Pipeline")
97 .build()?;
98
99 Ok(Self {
100 context: context.clone(),
101 pipeline,
102 bind_group_layout,
103 workgroup_size: (16, 16),
104 })
105 }
106
107 fn convolution_shader() -> String {
109 r#"
110struct ConvolutionParams {
111 width: u32,
112 height: u32,
113 kernel_width: u32,
114 kernel_height: u32,
115}
116
117@group(0) @binding(0) var<storage, read> input: array<f32>;
118@group(0) @binding(1) var<storage, read> kernel: array<f32>;
119@group(0) @binding(2) var<uniform> params: ConvolutionParams;
120@group(0) @binding(3) var<storage, read_write> output: array<f32>;
121
122fn get_pixel(x: i32, y: i32) -> f32 {
123 // Clamp to image boundaries
124 let cx = clamp(x, 0, i32(params.width) - 1);
125 let cy = clamp(y, 0, i32(params.height) - 1);
126 return input[u32(cy) * params.width + u32(cx)];
127}
128
129@compute @workgroup_size(16, 16)
130fn convolve(@builtin(global_invocation_id) global_id: vec3<u32>) {
131 let x = global_id.x;
132 let y = global_id.y;
133
134 if (x >= params.width || y >= params.height) {
135 return;
136 }
137
138 let kernel_half_width = params.kernel_width / 2u;
139 let kernel_half_height = params.kernel_height / 2u;
140
141 var sum = 0.0;
142
143 for (var ky = 0u; ky < params.kernel_height; ky++) {
144 for (var kx = 0u; kx < params.kernel_width; kx++) {
145 let offset_x = i32(kx) - i32(kernel_half_width);
146 let offset_y = i32(ky) - i32(kernel_half_height);
147
148 let px = i32(x) + offset_x;
149 let py = i32(y) + offset_y;
150
151 let pixel_value = get_pixel(px, py);
152 let kernel_value = kernel[ky * params.kernel_width + kx];
153
154 sum += pixel_value * kernel_value;
155 }
156 }
157
158 output[y * params.width + x] = sum;
159}
160"#
161 .to_string()
162 }
163
164 pub fn execute<T: Pod>(
170 &self,
171 input: &GpuBuffer<T>,
172 kernel: &GpuBuffer<f32>,
173 params: ConvolutionParams,
174 ) -> GpuResult<GpuBuffer<T>> {
175 let expected_input_size = (params.width as usize) * (params.height as usize);
177 let expected_kernel_size = (params.kernel_width as usize) * (params.kernel_height as usize);
178
179 if input.len() != expected_input_size {
180 return Err(GpuError::invalid_kernel_params(format!(
181 "Input size mismatch: expected {}, got {}",
182 expected_input_size,
183 input.len()
184 )));
185 }
186
187 if kernel.len() != expected_kernel_size {
188 return Err(GpuError::invalid_kernel_params(format!(
189 "Kernel size mismatch: expected {}, got {}",
190 expected_kernel_size,
191 kernel.len()
192 )));
193 }
194
195 let output = GpuBuffer::new(
197 &self.context,
198 expected_input_size,
199 BufferUsages::STORAGE | BufferUsages::COPY_SRC,
200 )?;
201
202 let params_buffer = GpuBuffer::from_data(
204 &self.context,
205 &[params],
206 BufferUsages::UNIFORM | BufferUsages::COPY_DST,
207 )?;
208
209 let bind_group = self
211 .context
212 .device()
213 .create_bind_group(&BindGroupDescriptor {
214 label: Some("ConvolutionKernel BindGroup"),
215 layout: &self.bind_group_layout,
216 entries: &[
217 BindGroupEntry {
218 binding: 0,
219 resource: input.buffer().as_entire_binding(),
220 },
221 BindGroupEntry {
222 binding: 1,
223 resource: kernel.buffer().as_entire_binding(),
224 },
225 BindGroupEntry {
226 binding: 2,
227 resource: params_buffer.buffer().as_entire_binding(),
228 },
229 BindGroupEntry {
230 binding: 3,
231 resource: output.buffer().as_entire_binding(),
232 },
233 ],
234 });
235
236 let mut encoder = self
238 .context
239 .device()
240 .create_command_encoder(&CommandEncoderDescriptor {
241 label: Some("ConvolutionKernel Encoder"),
242 });
243
244 {
245 let mut compute_pass = encoder.begin_compute_pass(&ComputePassDescriptor {
246 label: Some("ConvolutionKernel Pass"),
247 timestamp_writes: None,
248 });
249
250 compute_pass.set_pipeline(&self.pipeline);
251 compute_pass.set_bind_group(0, &bind_group, &[]);
252
253 let workgroups_x = (params.width + self.workgroup_size.0 - 1) / self.workgroup_size.0;
254 let workgroups_y = (params.height + self.workgroup_size.1 - 1) / self.workgroup_size.1;
255
256 compute_pass.dispatch_workgroups(workgroups_x, workgroups_y, 1);
257 }
258
259 self.context.queue().submit(Some(encoder.finish()));
260
261 debug!(
262 "Convolved {}x{} with {}x{} kernel",
263 params.width, params.height, params.kernel_width, params.kernel_height
264 );
265
266 Ok(output)
267 }
268}
269
270pub struct Filters;
272
273impl Filters {
274 pub fn gaussian_3x3() -> Vec<f32> {
276 vec![
277 1.0 / 16.0,
278 2.0 / 16.0,
279 1.0 / 16.0,
280 2.0 / 16.0,
281 4.0 / 16.0,
282 2.0 / 16.0,
283 1.0 / 16.0,
284 2.0 / 16.0,
285 1.0 / 16.0,
286 ]
287 }
288
289 pub fn gaussian_5x5() -> Vec<f32> {
291 #[allow(clippy::excessive_precision)]
292 let kernel = vec![
293 1.0, 4.0, 6.0, 4.0, 1.0, 4.0, 16.0, 24.0, 16.0, 4.0, 6.0, 24.0, 36.0, 24.0, 6.0, 4.0,
294 16.0, 24.0, 16.0, 4.0, 1.0, 4.0, 6.0, 4.0, 1.0,
295 ];
296 let sum: f32 = kernel.iter().sum();
297 kernel.iter().map(|v| v / sum).collect()
298 }
299
300 pub fn sobel_horizontal() -> Vec<f32> {
302 vec![-1.0, 0.0, 1.0, -2.0, 0.0, 2.0, -1.0, 0.0, 1.0]
303 }
304
305 pub fn sobel_vertical() -> Vec<f32> {
307 vec![-1.0, -2.0, -1.0, 0.0, 0.0, 0.0, 1.0, 2.0, 1.0]
308 }
309
310 pub fn laplacian() -> Vec<f32> {
312 vec![0.0, 1.0, 0.0, 1.0, -4.0, 1.0, 0.0, 1.0, 0.0]
313 }
314
315 pub fn sharpen() -> Vec<f32> {
317 vec![0.0, -1.0, 0.0, -1.0, 5.0, -1.0, 0.0, -1.0, 0.0]
318 }
319
320 pub fn box_blur_3x3() -> Vec<f32> {
322 vec![
323 1.0 / 9.0,
324 1.0 / 9.0,
325 1.0 / 9.0,
326 1.0 / 9.0,
327 1.0 / 9.0,
328 1.0 / 9.0,
329 1.0 / 9.0,
330 1.0 / 9.0,
331 1.0 / 9.0,
332 ]
333 }
334
335 pub fn emboss() -> Vec<f32> {
337 vec![-2.0, -1.0, 0.0, -1.0, 1.0, 1.0, 0.0, 1.0, 2.0]
338 }
339
340 pub fn gaussian_custom(size: usize, sigma: f32) -> crate::error::GpuResult<Vec<f32>> {
346 if size % 2 == 0 {
347 return Err(crate::error::GpuError::InvalidKernelParams {
348 reason: "Kernel size must be odd".to_string(),
349 });
350 }
351
352 let center = (size / 2) as i32;
353 let mut kernel = vec![0.0; size * size];
354
355 let two_sigma_sq = 2.0 * sigma * sigma;
356 let mut sum = 0.0;
357
358 for y in 0..size {
359 for x in 0..size {
360 let dx = (x as i32 - center) as f32;
361 let dy = (y as i32 - center) as f32;
362 let dist_sq = dx * dx + dy * dy;
363
364 let value = (-dist_sq / two_sigma_sq).exp();
365 kernel[y * size + x] = value;
366 sum += value;
367 }
368 }
369
370 Ok(kernel.iter().map(|v| v / sum).collect())
372 }
373}
374
375pub fn gaussian_blur<T: Pod>(
381 context: &GpuContext,
382 input: &GpuBuffer<T>,
383 width: u32,
384 height: u32,
385 sigma: f32,
386) -> GpuResult<GpuBuffer<T>> {
387 let kernel_size = ((sigma * 6.0).ceil() as u32) | 1; let kernel_size = kernel_size.max(3).min(15); let kernel_data = Filters::gaussian_custom(kernel_size as usize, sigma)?;
392 let kernel = GpuBuffer::from_data(
393 context,
394 &kernel_data,
395 BufferUsages::STORAGE | BufferUsages::COPY_DST,
396 )?;
397
398 let conv_kernel = ConvolutionKernel::new(context)?;
399 let params = ConvolutionParams::square(width, height, kernel_size)?;
400
401 conv_kernel.execute(input, &kernel, params)
402}
403
404pub fn sobel_edge_detection<T: Pod + Zeroable>(
410 context: &GpuContext,
411 input: &GpuBuffer<T>,
412 width: u32,
413 height: u32,
414) -> GpuResult<GpuBuffer<T>> {
415 let conv_kernel = ConvolutionKernel::new(context)?;
416 let params = ConvolutionParams::square(width, height, 3)?;
417
418 let h_kernel = GpuBuffer::from_data(
420 context,
421 &Filters::sobel_horizontal(),
422 BufferUsages::STORAGE | BufferUsages::COPY_DST,
423 )?;
424 let h_edges = conv_kernel.execute(input, &h_kernel, params)?;
425
426 let v_kernel = GpuBuffer::from_data(
428 context,
429 &Filters::sobel_vertical(),
430 BufferUsages::STORAGE | BufferUsages::COPY_DST,
431 )?;
432 let _v_edges = conv_kernel.execute(input, &v_kernel, params)?;
433
434 Ok(h_edges)
438}
439
440pub fn apply_filter<T: Pod>(
446 context: &GpuContext,
447 input: &GpuBuffer<T>,
448 width: u32,
449 height: u32,
450 kernel_data: &[f32],
451 kernel_size: u32,
452) -> GpuResult<GpuBuffer<T>> {
453 let kernel = GpuBuffer::from_data(
454 context,
455 kernel_data,
456 BufferUsages::STORAGE | BufferUsages::COPY_DST,
457 )?;
458
459 let conv_kernel = ConvolutionKernel::new(context)?;
460 let params = ConvolutionParams::square(width, height, kernel_size)?;
461
462 conv_kernel.execute(input, &kernel, params)
463}
464
#[cfg(test)]
#[allow(clippy::panic)]
mod tests {
    use super::*;

    #[test]
    fn test_convolution_params() {
        // Odd kernel dimensions are accepted.
        let params = ConvolutionParams::new(1024, 768, 3, 3);
        assert!(params.is_ok());

        // `expect` replaces the previous `.ok().unwrap_or_else(|| panic!(..))`
        // chain, which discarded the error before panicking.
        let params = params.expect("Failed to create params");
        assert_eq!(params.kernel_center(), (1, 1));

        // Even kernel dimensions are rejected.
        let params = ConvolutionParams::new(1024, 768, 4, 4);
        assert!(params.is_err());
    }

    #[test]
    fn test_filter_kernels() {
        let gaussian = Filters::gaussian_3x3();
        assert_eq!(gaussian.len(), 9);

        let sum: f32 = gaussian.iter().sum();
        assert!(
            (sum - 1.0).abs() < 1e-5,
            "Gaussian kernel should sum to 1.0"
        );

        let sobel = Filters::sobel_horizontal();
        assert_eq!(sobel.len(), 9);

        let laplacian = Filters::laplacian();
        assert_eq!(laplacian.len(), 9);
    }

    #[test]
    fn test_gaussian_custom() {
        let kernel = Filters::gaussian_custom(5, 1.0).expect("Failed to create kernel");
        assert_eq!(kernel.len(), 25);

        let sum: f32 = kernel.iter().sum();
        assert!(
            (sum - 1.0).abs() < 1e-5,
            "Custom Gaussian should sum to 1.0"
        );

        // The center element (index 12 of a 5x5 kernel) must carry the
        // largest weight.
        let center_value = kernel[12];
        for (i, &value) in kernel.iter().enumerate() {
            if i != 12 {
                assert!(value <= center_value);
            }
        }
    }

    #[tokio::test]
    async fn test_convolution_kernel() {
        // GPU availability is environment-dependent, so this is only a smoke
        // test: when a context can be obtained, pipeline construction must
        // not panic. The result itself is deliberately ignored.
        if let Ok(context) = GpuContext::new().await {
            let _ = ConvolutionKernel::new(&context);
        }
    }

    #[test]
    fn test_gaussian_custom_even_size() {
        let result = Filters::gaussian_custom(4, 1.0);
        assert!(result.is_err());
    }
}