Skip to main content

oximedia_gpu/ops/
filter.rs

1//! Convolution filter operations (blur, sharpen, edge detection)
2
3use crate::{
4    shader::{BindGroupLayoutBuilder, ShaderCompiler, ShaderSource},
5    GpuDevice, GpuError, Result,
6};
7use bytemuck::{Pod, Zeroable};
8use once_cell::sync::OnceCell;
9use wgpu::{BindGroup, BindGroupLayout, ComputePipeline};
10
11use super::utils;
12
13#[repr(C)]
14#[derive(Copy, Clone, Pod, Zeroable)]
15struct FilterParams {
16    width: u32,
17    height: u32,
18    stride: u32,
19    kernel_size: u32,
20    normalize: u32,
21    filter_type: u32,
22    padding: u32,
23    sigma: f32,
24}
25
/// Convolution filter operations (Gaussian blur, sharpen, Sobel edge detection and
/// custom kernels) executed on the GPU over packed RGBA images.
///
/// This is a stateless namespace type; all operations are associated functions.
#[derive(Debug, Clone, Copy)]
pub struct FilterOperation;
28
29impl FilterOperation {
30    /// Apply Gaussian blur
31    ///
32    /// # Arguments
33    ///
34    /// * `device` - GPU device
35    /// * `input` - Input image buffer (packed RGBA format)
36    /// * `output` - Output image buffer (packed RGBA format)
37    /// * `width` - Image width
38    /// * `height` - Image height
39    /// * `sigma` - Blur radius (standard deviation)
40    ///
41    /// # Errors
42    ///
43    /// Returns an error if buffer sizes are invalid or if the GPU operation fails.
44    #[allow(clippy::too_many_arguments)]
45    pub fn gaussian_blur(
46        device: &GpuDevice,
47        input: &[u8],
48        output: &mut [u8],
49        width: u32,
50        height: u32,
51        sigma: f32,
52    ) -> Result<()> {
53        utils::validate_dimensions(width, height)?;
54        utils::validate_buffer_size(input, width, height, 4)?;
55        utils::validate_buffer_size(output, width, height, 4)?;
56
57        let kernel_size = Self::calculate_kernel_size(sigma);
58        let pipeline = Self::get_gaussian_pipeline(device)?;
59        let layout = Self::get_bind_group_layout(device)?;
60
61        Self::execute_filter(
62            device,
63            pipeline,
64            layout,
65            input,
66            output,
67            width,
68            height,
69            kernel_size,
70            1, // Gaussian filter type
71            sigma,
72        )
73    }
74
75    /// Apply sharpening filter (unsharp mask)
76    ///
77    /// # Arguments
78    ///
79    /// * `device` - GPU device
80    /// * `input` - Input image buffer (packed RGBA format)
81    /// * `output` - Output image buffer (packed RGBA format)
82    /// * `width` - Image width
83    /// * `height` - Image height
84    /// * `amount` - Sharpening strength
85    ///
86    /// # Errors
87    ///
88    /// Returns an error if buffer sizes are invalid or if the GPU operation fails.
89    #[allow(clippy::too_many_arguments)]
90    pub fn sharpen(
91        device: &GpuDevice,
92        input: &[u8],
93        output: &mut [u8],
94        width: u32,
95        height: u32,
96        amount: f32,
97    ) -> Result<()> {
98        utils::validate_dimensions(width, height)?;
99        utils::validate_buffer_size(input, width, height, 4)?;
100        utils::validate_buffer_size(output, width, height, 4)?;
101
102        let pipeline = Self::get_sharpen_pipeline(device)?;
103        let layout = Self::get_bind_group_layout(device)?;
104
105        Self::execute_filter(
106            device, pipeline, layout, input, output, width, height,
107            5, // Kernel size for sharpening
108            2, // Sharpen filter type
109            amount,
110        )
111    }
112
113    /// Detect edges using Sobel operator
114    ///
115    /// # Arguments
116    ///
117    /// * `device` - GPU device
118    /// * `input` - Input image buffer (packed RGBA format)
119    /// * `output` - Output image buffer (packed RGBA format)
120    /// * `width` - Image width
121    /// * `height` - Image height
122    ///
123    /// # Errors
124    ///
125    /// Returns an error if buffer sizes are invalid or if the GPU operation fails.
126    pub fn edge_detect(
127        device: &GpuDevice,
128        input: &[u8],
129        output: &mut [u8],
130        width: u32,
131        height: u32,
132    ) -> Result<()> {
133        utils::validate_dimensions(width, height)?;
134        utils::validate_buffer_size(input, width, height, 4)?;
135        utils::validate_buffer_size(output, width, height, 4)?;
136
137        let pipeline = Self::get_edge_detect_pipeline(device)?;
138        let layout = Self::get_bind_group_layout(device)?;
139
140        Self::execute_filter(
141            device, pipeline, layout, input, output, width, height, 3, // 3x3 Sobel kernel
142            3, // Edge detect filter type
143            0.0,
144        )
145    }
146
147    /// Apply custom convolution kernel
148    ///
149    /// # Arguments
150    ///
151    /// * `device` - GPU device
152    /// * `input` - Input image buffer (packed RGBA format)
153    /// * `output` - Output image buffer (packed RGBA format)
154    /// * `width` - Image width
155    /// * `height` - Image height
156    /// * `kernel` - Convolution kernel (must be square and odd-sized)
157    /// * `normalize` - Whether to normalize the kernel
158    ///
159    /// # Errors
160    ///
161    /// Returns an error if buffer sizes are invalid or if the GPU operation fails.
162    #[allow(clippy::too_many_arguments)]
163    pub fn convolve(
164        device: &GpuDevice,
165        input: &[u8],
166        output: &mut [u8],
167        width: u32,
168        height: u32,
169        kernel: &[f32],
170        normalize: bool,
171    ) -> Result<()> {
172        utils::validate_dimensions(width, height)?;
173        utils::validate_buffer_size(input, width, height, 4)?;
174        utils::validate_buffer_size(output, width, height, 4)?;
175
176        let kernel_size = (kernel.len() as f32).sqrt() as u32;
177        if kernel_size * kernel_size != kernel.len() as u32 {
178            return Err(GpuError::Internal("Kernel must be square".to_string()));
179        }
180        if kernel_size % 2 == 0 {
181            return Err(GpuError::Internal("Kernel size must be odd".to_string()));
182        }
183
184        let pipeline = Self::get_convolve_pipeline(device)?;
185        let layout = Self::get_bind_group_layout_with_kernel(device)?;
186
187        Self::execute_convolve(
188            device,
189            pipeline,
190            layout,
191            input,
192            output,
193            width,
194            height,
195            kernel,
196            kernel_size,
197            normalize,
198        )
199    }
200
201    #[allow(clippy::too_many_arguments)]
202    fn execute_filter(
203        device: &GpuDevice,
204        pipeline: &ComputePipeline,
205        layout: &BindGroupLayout,
206        input: &[u8],
207        output: &mut [u8],
208        width: u32,
209        height: u32,
210        kernel_size: u32,
211        filter_type: u32,
212        sigma: f32,
213    ) -> Result<()> {
214        // Create buffers
215        let input_buffer = utils::create_storage_buffer(device, input.len() as u64)?;
216        let output_buffer = utils::create_storage_buffer(device, output.len() as u64)?;
217
218        // Upload input data
219        device.queue().write_buffer(input_buffer.buffer(), 0, input);
220
221        // Create uniform buffer for parameters
222        let params = FilterParams {
223            width,
224            height,
225            stride: width,
226            kernel_size,
227            normalize: 1,
228            filter_type,
229            padding: 0,
230            sigma,
231        };
232        let params_bytes = bytemuck::bytes_of(&params);
233        let params_buffer = utils::create_uniform_buffer(device, params_bytes)?;
234
235        // Create bind group
236        let compiler = ShaderCompiler::new(device);
237        let bind_group = compiler.create_bind_group(
238            "Filter Bind Group",
239            layout,
240            &[
241                wgpu::BindGroupEntry {
242                    binding: 0,
243                    resource: input_buffer.buffer().as_entire_binding(),
244                },
245                wgpu::BindGroupEntry {
246                    binding: 1,
247                    resource: output_buffer.buffer().as_entire_binding(),
248                },
249                wgpu::BindGroupEntry {
250                    binding: 2,
251                    resource: params_buffer.buffer().as_entire_binding(),
252                },
253            ],
254        );
255
256        // Execute compute pass
257        Self::dispatch_compute(device, pipeline, &bind_group, width, height)?;
258
259        // Read back results
260        let readback_buffer = utils::create_readback_buffer(device, output.len() as u64)?;
261        let mut encoder = device
262            .device()
263            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
264                label: Some("Filter Copy Encoder"),
265            });
266
267        output_buffer.copy_to(&mut encoder, &readback_buffer, 0, 0, output.len() as u64)?;
268
269        device.queue().submit(Some(encoder.finish()));
270        device.wait();
271
272        let result = readback_buffer.read(device, 0, output.len() as u64)?;
273        output.copy_from_slice(&result);
274
275        Ok(())
276    }
277
278    #[allow(clippy::too_many_arguments)]
279    fn execute_convolve(
280        device: &GpuDevice,
281        pipeline: &ComputePipeline,
282        layout: &BindGroupLayout,
283        input: &[u8],
284        output: &mut [u8],
285        width: u32,
286        height: u32,
287        kernel: &[f32],
288        kernel_size: u32,
289        normalize: bool,
290    ) -> Result<()> {
291        // Create buffers
292        let input_buffer = utils::create_storage_buffer(device, input.len() as u64)?;
293        let output_buffer = utils::create_storage_buffer(device, output.len() as u64)?;
294
295        // Upload input data
296        device.queue().write_buffer(input_buffer.buffer(), 0, input);
297
298        // Create kernel buffer
299        let kernel_bytes = bytemuck::cast_slice(kernel);
300        let kernel_buffer = utils::create_storage_buffer(device, kernel_bytes.len() as u64)?;
301        device
302            .queue()
303            .write_buffer(kernel_buffer.buffer(), 0, kernel_bytes);
304
305        // Create uniform buffer for parameters
306        let params = FilterParams {
307            width,
308            height,
309            stride: width,
310            kernel_size,
311            normalize: u32::from(normalize),
312            filter_type: 0, // Custom kernel
313            padding: 0,
314            sigma: 0.0,
315        };
316        let params_bytes = bytemuck::bytes_of(&params);
317        let params_buffer = utils::create_uniform_buffer(device, params_bytes)?;
318
319        // Create bind group
320        let compiler = ShaderCompiler::new(device);
321        let bind_group = compiler.create_bind_group(
322            "Filter Bind Group",
323            layout,
324            &[
325                wgpu::BindGroupEntry {
326                    binding: 0,
327                    resource: input_buffer.buffer().as_entire_binding(),
328                },
329                wgpu::BindGroupEntry {
330                    binding: 1,
331                    resource: output_buffer.buffer().as_entire_binding(),
332                },
333                wgpu::BindGroupEntry {
334                    binding: 2,
335                    resource: params_buffer.buffer().as_entire_binding(),
336                },
337                wgpu::BindGroupEntry {
338                    binding: 3,
339                    resource: kernel_buffer.buffer().as_entire_binding(),
340                },
341            ],
342        );
343
344        // Execute compute pass
345        Self::dispatch_compute(device, pipeline, &bind_group, width, height)?;
346
347        // Read back results
348        let readback_buffer = utils::create_readback_buffer(device, output.len() as u64)?;
349        let mut encoder = device
350            .device()
351            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
352                label: Some("Filter Copy Encoder"),
353            });
354
355        output_buffer.copy_to(&mut encoder, &readback_buffer, 0, 0, output.len() as u64)?;
356
357        device.queue().submit(Some(encoder.finish()));
358        device.wait();
359
360        let result = readback_buffer.read(device, 0, output.len() as u64)?;
361        output.copy_from_slice(&result);
362
363        Ok(())
364    }
365
366    fn dispatch_compute(
367        device: &GpuDevice,
368        pipeline: &ComputePipeline,
369        bind_group: &BindGroup,
370        width: u32,
371        height: u32,
372    ) -> Result<()> {
373        let mut encoder = device
374            .device()
375            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
376                label: Some("Filter Compute Encoder"),
377            });
378
379        {
380            let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
381                label: Some("Filter Compute Pass"),
382                timestamp_writes: None,
383            });
384
385            compute_pass.set_pipeline(pipeline);
386            compute_pass.set_bind_group(0, bind_group, &[]);
387
388            let (dispatch_x, dispatch_y) = utils::calculate_dispatch_size(width, height, (16, 16));
389            compute_pass.dispatch_workgroups(dispatch_x, dispatch_y, 1);
390        }
391
392        device.queue().submit(Some(encoder.finish()));
393        Ok(())
394    }
395
396    fn calculate_kernel_size(sigma: f32) -> u32 {
397        // Use 3-sigma rule: kernel size = 2 * ceil(3 * sigma) + 1
398        let radius = (3.0 * sigma).ceil() as u32;
399        2 * radius + 1
400    }
401
402    fn get_bind_group_layout(device: &GpuDevice) -> Result<&'static BindGroupLayout> {
403        static LAYOUT: OnceCell<BindGroupLayout> = OnceCell::new();
404
405        Ok(LAYOUT.get_or_init(|| {
406            let compiler = ShaderCompiler::new(device);
407            let entries = BindGroupLayoutBuilder::new()
408                .add_storage_buffer_read_only(0) // input
409                .add_storage_buffer(1) // output
410                .add_uniform_buffer(2) // params
411                .build();
412
413            compiler.create_bind_group_layout("Filter Bind Group Layout", &entries)
414        }))
415    }
416
417    fn get_bind_group_layout_with_kernel(device: &GpuDevice) -> Result<&'static BindGroupLayout> {
418        static LAYOUT: OnceCell<BindGroupLayout> = OnceCell::new();
419
420        Ok(LAYOUT.get_or_init(|| {
421            let compiler = ShaderCompiler::new(device);
422            let entries = BindGroupLayoutBuilder::new()
423                .add_storage_buffer_read_only(0) // input
424                .add_storage_buffer(1) // output
425                .add_uniform_buffer(2) // params
426                .add_storage_buffer_read_only(3) // kernel
427                .build();
428
429            compiler.create_bind_group_layout("Filter Bind Group Layout (with kernel)", &entries)
430        }))
431    }
432
433    fn get_gaussian_pipeline(device: &GpuDevice) -> Result<&'static ComputePipeline> {
434        static PIPELINE: OnceCell<ComputePipeline> = OnceCell::new();
435
436        Ok(PIPELINE.get_or_init(|| {
437            let compiler = ShaderCompiler::new(device);
438            let shader = compiler
439                .compile(
440                    "Filter Shader",
441                    ShaderSource::Embedded(crate::shader::embedded::FILTER_SHADER),
442                )
443                .expect("Failed to compile filter shader");
444
445            let layout =
446                Self::get_bind_group_layout(device).expect("Failed to create bind group layout");
447
448            compiler
449                .create_pipeline("Gaussian Blur Pipeline", &shader, "convolve_main", layout)
450                .expect("Failed to create pipeline")
451        }))
452    }
453
454    fn get_sharpen_pipeline(device: &GpuDevice) -> Result<&'static ComputePipeline> {
455        static PIPELINE: OnceCell<ComputePipeline> = OnceCell::new();
456
457        Ok(PIPELINE.get_or_init(|| {
458            let compiler = ShaderCompiler::new(device);
459            let shader = compiler
460                .compile(
461                    "Filter Shader",
462                    ShaderSource::Embedded(crate::shader::embedded::FILTER_SHADER),
463                )
464                .expect("Failed to compile filter shader");
465
466            let layout =
467                Self::get_bind_group_layout(device).expect("Failed to create bind group layout");
468
469            compiler
470                .create_pipeline("Sharpen Pipeline", &shader, "unsharp_mask", layout)
471                .expect("Failed to create pipeline")
472        }))
473    }
474
475    fn get_edge_detect_pipeline(device: &GpuDevice) -> Result<&'static ComputePipeline> {
476        static PIPELINE: OnceCell<ComputePipeline> = OnceCell::new();
477
478        Ok(PIPELINE.get_or_init(|| {
479            let compiler = ShaderCompiler::new(device);
480            let shader = compiler
481                .compile(
482                    "Filter Shader",
483                    ShaderSource::Embedded(crate::shader::embedded::FILTER_SHADER),
484                )
485                .expect("Failed to compile filter shader");
486
487            let layout =
488                Self::get_bind_group_layout(device).expect("Failed to create bind group layout");
489
490            compiler
491                .create_pipeline("Edge Detect Pipeline", &shader, "edge_detect", layout)
492                .expect("Failed to create pipeline")
493        }))
494    }
495
496    fn get_convolve_pipeline(device: &GpuDevice) -> Result<&'static ComputePipeline> {
497        static PIPELINE: OnceCell<ComputePipeline> = OnceCell::new();
498
499        Ok(PIPELINE.get_or_init(|| {
500            let compiler = ShaderCompiler::new(device);
501            let shader = compiler
502                .compile(
503                    "Filter Shader",
504                    ShaderSource::Embedded(crate::shader::embedded::FILTER_SHADER),
505                )
506                .expect("Failed to compile filter shader");
507
508            let layout = Self::get_bind_group_layout_with_kernel(device)
509                .expect("Failed to create bind group layout");
510
511            compiler
512                .create_pipeline("Convolve Pipeline", &shader, "convolve_main", layout)
513                .expect("Failed to create pipeline")
514        }))
515    }
516}