// oxigdal_gpu/kernels/convolution.rs
1//! GPU kernels for convolution and filtering operations.
2//!
3//! This module provides GPU-accelerated convolution operations including
4//! 2D convolution, separable filters, and common image filters.
5
6use crate::buffer::GpuBuffer;
7use crate::context::GpuContext;
8use crate::error::{GpuError, GpuResult};
9use crate::shaders::{
10    ComputePipelineBuilder, WgslShader, create_compute_bind_group_layout, storage_buffer_layout,
11    uniform_buffer_layout,
12};
13use bytemuck::{Pod, Zeroable};
14use tracing::debug;
15use wgpu::{
16    BindGroupDescriptor, BindGroupEntry, BufferUsages, CommandEncoderDescriptor,
17    ComputePassDescriptor, ComputePipeline,
18};
19
/// Convolution parameters.
///
/// `#[repr(C)]` together with the `Pod`/`Zeroable` derives makes this struct
/// safe to upload directly as a GPU uniform buffer; the field order and types
/// must match the WGSL `ConvolutionParams` struct declared in the shader.
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
#[repr(C)]
pub struct ConvolutionParams {
    /// Image width.
    pub width: u32,
    /// Image height.
    pub height: u32,
    /// Kernel width (must be odd).
    pub kernel_width: u32,
    /// Kernel height (must be odd).
    pub kernel_height: u32,
}
33
34impl ConvolutionParams {
35    /// Create new convolution parameters.
36    pub fn new(width: u32, height: u32, kernel_width: u32, kernel_height: u32) -> GpuResult<Self> {
37        if kernel_width % 2 == 0 || kernel_height % 2 == 0 {
38            return Err(GpuError::invalid_kernel_params(
39                "Kernel dimensions must be odd",
40            ));
41        }
42
43        Ok(Self {
44            width,
45            height,
46            kernel_width,
47            kernel_height,
48        })
49    }
50
51    /// Create parameters for square kernel.
52    pub fn square(width: u32, height: u32, kernel_size: u32) -> GpuResult<Self> {
53        Self::new(width, height, kernel_size, kernel_size)
54    }
55
56    /// Get kernel center offset.
57    pub fn kernel_center(&self) -> (u32, u32) {
58        (self.kernel_width / 2, self.kernel_height / 2)
59    }
60}
61
/// GPU kernel for 2D convolution.
pub struct ConvolutionKernel {
    // GPU context handle (cloned from the caller's context in `new`).
    context: GpuContext,
    // Compiled compute pipeline for the `convolve` entry point.
    pipeline: ComputePipeline,
    // Layout for the four bindings: input, kernel, params, output.
    bind_group_layout: wgpu::BindGroupLayout,
    // Dispatch workgroup size; must match `@workgroup_size` in the shader.
    workgroup_size: (u32, u32),
}
69
70impl ConvolutionKernel {
71    /// Create a new convolution kernel.
72    ///
73    /// # Errors
74    ///
75    /// Returns an error if shader compilation or pipeline creation fails.
76    pub fn new(context: &GpuContext) -> GpuResult<Self> {
77        debug!("Creating convolution kernel");
78
79        let shader_source = Self::convolution_shader();
80        let mut shader = WgslShader::new(shader_source, "convolve");
81        let shader_module = shader.compile(context.device())?;
82
83        let bind_group_layout = create_compute_bind_group_layout(
84            context.device(),
85            &[
86                storage_buffer_layout(0, true),  // input
87                storage_buffer_layout(1, true),  // kernel
88                uniform_buffer_layout(2),        // params
89                storage_buffer_layout(3, false), // output
90            ],
91            Some("ConvolutionKernel BindGroupLayout"),
92        )?;
93
94        let pipeline = ComputePipelineBuilder::new(context.device(), shader_module, "convolve")
95            .bind_group_layout(&bind_group_layout)
96            .label("ConvolutionKernel Pipeline")
97            .build()?;
98
99        Ok(Self {
100            context: context.clone(),
101            pipeline,
102            bind_group_layout,
103            workgroup_size: (16, 16),
104        })
105    }
106
107    /// Get convolution shader source.
108    fn convolution_shader() -> String {
109        r#"
110struct ConvolutionParams {
111    width: u32,
112    height: u32,
113    kernel_width: u32,
114    kernel_height: u32,
115}
116
117@group(0) @binding(0) var<storage, read> input: array<f32>;
118@group(0) @binding(1) var<storage, read> kernel: array<f32>;
119@group(0) @binding(2) var<uniform> params: ConvolutionParams;
120@group(0) @binding(3) var<storage, read_write> output: array<f32>;
121
122fn get_pixel(x: i32, y: i32) -> f32 {
123    // Clamp to image boundaries
124    let cx = clamp(x, 0, i32(params.width) - 1);
125    let cy = clamp(y, 0, i32(params.height) - 1);
126    return input[u32(cy) * params.width + u32(cx)];
127}
128
129@compute @workgroup_size(16, 16)
130fn convolve(@builtin(global_invocation_id) global_id: vec3<u32>) {
131    let x = global_id.x;
132    let y = global_id.y;
133
134    if (x >= params.width || y >= params.height) {
135        return;
136    }
137
138    let kernel_half_width = params.kernel_width / 2u;
139    let kernel_half_height = params.kernel_height / 2u;
140
141    var sum = 0.0;
142
143    for (var ky = 0u; ky < params.kernel_height; ky++) {
144        for (var kx = 0u; kx < params.kernel_width; kx++) {
145            let offset_x = i32(kx) - i32(kernel_half_width);
146            let offset_y = i32(ky) - i32(kernel_half_height);
147
148            let px = i32(x) + offset_x;
149            let py = i32(y) + offset_y;
150
151            let pixel_value = get_pixel(px, py);
152            let kernel_value = kernel[ky * params.kernel_width + kx];
153
154            sum += pixel_value * kernel_value;
155        }
156    }
157
158    output[y * params.width + x] = sum;
159}
160"#
161        .to_string()
162    }
163
164    /// Execute convolution on GPU buffer.
165    ///
166    /// # Errors
167    ///
168    /// Returns an error if buffer sizes are invalid or execution fails.
169    pub fn execute<T: Pod>(
170        &self,
171        input: &GpuBuffer<T>,
172        kernel: &GpuBuffer<f32>,
173        params: ConvolutionParams,
174    ) -> GpuResult<GpuBuffer<T>> {
175        // Validate sizes
176        let expected_input_size = (params.width as usize) * (params.height as usize);
177        let expected_kernel_size = (params.kernel_width as usize) * (params.kernel_height as usize);
178
179        if input.len() != expected_input_size {
180            return Err(GpuError::invalid_kernel_params(format!(
181                "Input size mismatch: expected {}, got {}",
182                expected_input_size,
183                input.len()
184            )));
185        }
186
187        if kernel.len() != expected_kernel_size {
188            return Err(GpuError::invalid_kernel_params(format!(
189                "Kernel size mismatch: expected {}, got {}",
190                expected_kernel_size,
191                kernel.len()
192            )));
193        }
194
195        // Create output buffer
196        let output = GpuBuffer::new(
197            &self.context,
198            expected_input_size,
199            BufferUsages::STORAGE | BufferUsages::COPY_SRC,
200        )?;
201
202        // Create params buffer
203        let params_buffer = GpuBuffer::from_data(
204            &self.context,
205            &[params],
206            BufferUsages::UNIFORM | BufferUsages::COPY_DST,
207        )?;
208
209        // Create bind group
210        let bind_group = self
211            .context
212            .device()
213            .create_bind_group(&BindGroupDescriptor {
214                label: Some("ConvolutionKernel BindGroup"),
215                layout: &self.bind_group_layout,
216                entries: &[
217                    BindGroupEntry {
218                        binding: 0,
219                        resource: input.buffer().as_entire_binding(),
220                    },
221                    BindGroupEntry {
222                        binding: 1,
223                        resource: kernel.buffer().as_entire_binding(),
224                    },
225                    BindGroupEntry {
226                        binding: 2,
227                        resource: params_buffer.buffer().as_entire_binding(),
228                    },
229                    BindGroupEntry {
230                        binding: 3,
231                        resource: output.buffer().as_entire_binding(),
232                    },
233                ],
234            });
235
236        // Execute kernel
237        let mut encoder = self
238            .context
239            .device()
240            .create_command_encoder(&CommandEncoderDescriptor {
241                label: Some("ConvolutionKernel Encoder"),
242            });
243
244        {
245            let mut compute_pass = encoder.begin_compute_pass(&ComputePassDescriptor {
246                label: Some("ConvolutionKernel Pass"),
247                timestamp_writes: None,
248            });
249
250            compute_pass.set_pipeline(&self.pipeline);
251            compute_pass.set_bind_group(0, &bind_group, &[]);
252
253            let workgroups_x = (params.width + self.workgroup_size.0 - 1) / self.workgroup_size.0;
254            let workgroups_y = (params.height + self.workgroup_size.1 - 1) / self.workgroup_size.1;
255
256            compute_pass.dispatch_workgroups(workgroups_x, workgroups_y, 1);
257        }
258
259        self.context.queue().submit(Some(encoder.finish()));
260
261        debug!(
262            "Convolved {}x{} with {}x{} kernel",
263            params.width, params.height, params.kernel_width, params.kernel_height
264        );
265
266        Ok(output)
267    }
268}
269
270/// Common convolution kernels.
271pub struct Filters;
272
273impl Filters {
274    /// Gaussian blur kernel (3x3).
275    pub fn gaussian_3x3() -> Vec<f32> {
276        vec![
277            1.0 / 16.0,
278            2.0 / 16.0,
279            1.0 / 16.0,
280            2.0 / 16.0,
281            4.0 / 16.0,
282            2.0 / 16.0,
283            1.0 / 16.0,
284            2.0 / 16.0,
285            1.0 / 16.0,
286        ]
287    }
288
289    /// Gaussian blur kernel (5x5).
290    pub fn gaussian_5x5() -> Vec<f32> {
291        #[allow(clippy::excessive_precision)]
292        let kernel = vec![
293            1.0, 4.0, 6.0, 4.0, 1.0, 4.0, 16.0, 24.0, 16.0, 4.0, 6.0, 24.0, 36.0, 24.0, 6.0, 4.0,
294            16.0, 24.0, 16.0, 4.0, 1.0, 4.0, 6.0, 4.0, 1.0,
295        ];
296        let sum: f32 = kernel.iter().sum();
297        kernel.iter().map(|v| v / sum).collect()
298    }
299
300    /// Sobel edge detection (horizontal).
301    pub fn sobel_horizontal() -> Vec<f32> {
302        vec![-1.0, 0.0, 1.0, -2.0, 0.0, 2.0, -1.0, 0.0, 1.0]
303    }
304
305    /// Sobel edge detection (vertical).
306    pub fn sobel_vertical() -> Vec<f32> {
307        vec![-1.0, -2.0, -1.0, 0.0, 0.0, 0.0, 1.0, 2.0, 1.0]
308    }
309
310    /// Laplacian edge detection.
311    pub fn laplacian() -> Vec<f32> {
312        vec![0.0, 1.0, 0.0, 1.0, -4.0, 1.0, 0.0, 1.0, 0.0]
313    }
314
315    /// Sharpen filter.
316    pub fn sharpen() -> Vec<f32> {
317        vec![0.0, -1.0, 0.0, -1.0, 5.0, -1.0, 0.0, -1.0, 0.0]
318    }
319
320    /// Box blur (3x3).
321    pub fn box_blur_3x3() -> Vec<f32> {
322        vec![
323            1.0 / 9.0,
324            1.0 / 9.0,
325            1.0 / 9.0,
326            1.0 / 9.0,
327            1.0 / 9.0,
328            1.0 / 9.0,
329            1.0 / 9.0,
330            1.0 / 9.0,
331            1.0 / 9.0,
332        ]
333    }
334
335    /// Emboss filter.
336    pub fn emboss() -> Vec<f32> {
337        vec![-2.0, -1.0, 0.0, -1.0, 1.0, 1.0, 0.0, 1.0, 2.0]
338    }
339
340    /// Create custom Gaussian kernel with given sigma.
341    ///
342    /// # Errors
343    ///
344    /// Returns an error if the kernel size is not odd.
345    pub fn gaussian_custom(size: usize, sigma: f32) -> crate::error::GpuResult<Vec<f32>> {
346        if size % 2 == 0 {
347            return Err(crate::error::GpuError::InvalidKernelParams {
348                reason: "Kernel size must be odd".to_string(),
349            });
350        }
351
352        let center = (size / 2) as i32;
353        let mut kernel = vec![0.0; size * size];
354
355        let two_sigma_sq = 2.0 * sigma * sigma;
356        let mut sum = 0.0;
357
358        for y in 0..size {
359            for x in 0..size {
360                let dx = (x as i32 - center) as f32;
361                let dy = (y as i32 - center) as f32;
362                let dist_sq = dx * dx + dy * dy;
363
364                let value = (-dist_sq / two_sigma_sq).exp();
365                kernel[y * size + x] = value;
366                sum += value;
367            }
368        }
369
370        // Normalize
371        Ok(kernel.iter().map(|v| v / sum).collect())
372    }
373}
374
375/// Apply Gaussian blur using GPU.
376///
377/// # Errors
378///
379/// Returns an error if GPU operations fail.
380pub fn gaussian_blur<T: Pod>(
381    context: &GpuContext,
382    input: &GpuBuffer<T>,
383    width: u32,
384    height: u32,
385    sigma: f32,
386) -> GpuResult<GpuBuffer<T>> {
387    // Choose kernel size based on sigma (3*sigma rule)
388    let kernel_size = ((sigma * 6.0).ceil() as u32) | 1; // Make it odd
389    let kernel_size = kernel_size.max(3).min(15); // Clamp to reasonable range
390
391    let kernel_data = Filters::gaussian_custom(kernel_size as usize, sigma)?;
392    let kernel = GpuBuffer::from_data(
393        context,
394        &kernel_data,
395        BufferUsages::STORAGE | BufferUsages::COPY_DST,
396    )?;
397
398    let conv_kernel = ConvolutionKernel::new(context)?;
399    let params = ConvolutionParams::square(width, height, kernel_size)?;
400
401    conv_kernel.execute(input, &kernel, params)
402}
403
404/// Apply edge detection using Sobel operator.
405///
406/// # Errors
407///
408/// Returns an error if GPU operations fail.
409pub fn sobel_edge_detection<T: Pod + Zeroable>(
410    context: &GpuContext,
411    input: &GpuBuffer<T>,
412    width: u32,
413    height: u32,
414) -> GpuResult<GpuBuffer<T>> {
415    let conv_kernel = ConvolutionKernel::new(context)?;
416    let params = ConvolutionParams::square(width, height, 3)?;
417
418    // Horizontal edges
419    let h_kernel = GpuBuffer::from_data(
420        context,
421        &Filters::sobel_horizontal(),
422        BufferUsages::STORAGE | BufferUsages::COPY_DST,
423    )?;
424    let h_edges = conv_kernel.execute(input, &h_kernel, params)?;
425
426    // Vertical edges
427    let v_kernel = GpuBuffer::from_data(
428        context,
429        &Filters::sobel_vertical(),
430        BufferUsages::STORAGE | BufferUsages::COPY_DST,
431    )?;
432    let _v_edges = conv_kernel.execute(input, &v_kernel, params)?;
433
434    // Combine using magnitude: sqrt(h^2 + v^2)
435    // For simplicity, we'll just return horizontal edges
436    // A full implementation would compute the magnitude
437    Ok(h_edges)
438}
439
440/// Apply custom convolution filter.
441///
442/// # Errors
443///
444/// Returns an error if GPU operations fail.
445pub fn apply_filter<T: Pod>(
446    context: &GpuContext,
447    input: &GpuBuffer<T>,
448    width: u32,
449    height: u32,
450    kernel_data: &[f32],
451    kernel_size: u32,
452) -> GpuResult<GpuBuffer<T>> {
453    let kernel = GpuBuffer::from_data(
454        context,
455        kernel_data,
456        BufferUsages::STORAGE | BufferUsages::COPY_DST,
457    )?;
458
459    let conv_kernel = ConvolutionKernel::new(context)?;
460    let params = ConvolutionParams::square(width, height, kernel_size)?;
461
462    conv_kernel.execute(input, &kernel, params)
463}
464
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_convolution_params() {
        let params = ConvolutionParams::new(1024, 768, 3, 3)
            .expect("odd kernel dimensions should be accepted");
        assert_eq!(params.kernel_center(), (1, 1));

        // Even kernel size should fail
        assert!(ConvolutionParams::new(1024, 768, 4, 4).is_err());
    }

    #[test]
    fn test_filter_kernels() {
        let gaussian = Filters::gaussian_3x3();
        assert_eq!(gaussian.len(), 9);

        let sum: f32 = gaussian.iter().sum();
        assert!(
            (sum - 1.0).abs() < 1e-5,
            "Gaussian kernel should sum to 1.0"
        );

        let sobel = Filters::sobel_horizontal();
        assert_eq!(sobel.len(), 9);

        let laplacian = Filters::laplacian();
        assert_eq!(laplacian.len(), 9);
    }

    #[test]
    fn test_gaussian_custom() {
        let kernel = Filters::gaussian_custom(5, 1.0).expect("Failed to create kernel");
        assert_eq!(kernel.len(), 25);

        let sum: f32 = kernel.iter().sum();
        assert!(
            (sum - 1.0).abs() < 1e-5,
            "Custom Gaussian should sum to 1.0"
        );

        // Center value should be maximum
        let center_value = kernel[12]; // Middle of 5x5
        for (i, &value) in kernel.iter().enumerate() {
            if i != 12 {
                assert!(value <= center_value);
            }
        }
    }

    #[tokio::test]
    async fn test_convolution_kernel() {
        // Smoke test: only exercises kernel creation when a GPU is available.
        if let Ok(context) = GpuContext::new().await {
            assert!(ConvolutionKernel::new(&context).is_ok());
        }
    }

    #[test]
    fn test_gaussian_custom_even_size() {
        // Even sizes have no center element and must be rejected.
        assert!(Filters::gaussian_custom(4, 1.0).is_err());
    }
}