Skip to main content

oximedia_gpu/kernels/
filter.rs

1//! Image filtering and convolution kernels
2
3use crate::{GpuDevice, Result};
4
/// The kind of image filter a [`FilterKernel`] applies.
///
/// The variants fall into four groups: smoothing (Gaussian, box, median, bilateral),
/// sharpening (unsharp mask), edge detection (Sobel, Scharr, Laplacian) and
/// convolution with a caller-supplied kernel (`Custom`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FilterType {
    /// Blur with Gaussian-weighted neighbourhood averaging.
    GaussianBlur,
    /// Blur with an unweighted (plain) average over a square window.
    BoxBlur,
    /// Replace each pixel with the median of its neighbourhood.
    Median,
    /// Edge-preserving smoothing driven by a spatial sigma and a range sigma.
    Bilateral,
    /// Sharpen the image with an unsharp mask.
    UnsharpMask,
    /// Detect edges with the Sobel operator.
    Sobel,
    /// Detect edges with the Scharr operator.
    Scharr,
    /// Detect edges with the Laplacian operator.
    Laplacian,
    /// Convolve with a caller-supplied kernel.
    Custom,
}
27
/// A dense 2D convolution kernel with odd width and height.
///
/// Build one with the validating constructor [`ConvolutionKernel::new`] or with one of the
/// preset constructors such as [`ConvolutionKernel::sobel_x`].
#[derive(Debug, Clone, PartialEq)]
pub struct ConvolutionKernel {
    /// Kernel weights; `new` guarantees exactly `width * height` entries.
    kernel: Vec<f32>,
    /// Kernel width in taps (odd).
    width: u32,
    /// Kernel height in taps (odd).
    height: u32,
    /// Forwarded unchanged to the convolution as its `normalize` flag.
    normalize: bool,
}
35
36impl ConvolutionKernel {
37    /// Create a new convolution kernel
38    ///
39    /// # Arguments
40    ///
41    /// * `kernel` - Kernel weights (must be width * height in size)
42    /// * `width` - Kernel width (must be odd)
43    /// * `height` - Kernel height (must be odd)
44    /// * `normalize` - Whether to normalize the kernel
45    pub fn new(kernel: Vec<f32>, width: u32, height: u32, normalize: bool) -> Result<Self> {
46        if kernel.len() != (width * height) as usize {
47            return Err(crate::GpuError::Internal(
48                "Kernel size mismatch".to_string(),
49            ));
50        }
51
52        if width % 2 == 0 || height % 2 == 0 {
53            return Err(crate::GpuError::Internal(
54                "Kernel dimensions must be odd".to_string(),
55            ));
56        }
57
58        Ok(Self {
59            kernel,
60            width,
61            height,
62            normalize,
63        })
64    }
65
66    /// Infallible constructor for hardcoded kernels whose dimensions are
67    /// compile-time known to be valid (odd, matching length).
68    fn from_hardcoded(kernel: Vec<f32>, width: u32, height: u32, normalize: bool) -> Self {
69        debug_assert_eq!(kernel.len(), (width * height) as usize);
70        debug_assert!(width % 2 == 1 && height % 2 == 1);
71        Self {
72            kernel,
73            width,
74            height,
75            normalize,
76        }
77    }
78
79    /// Create a Sobel X kernel (3x3)
80    #[must_use]
81    pub fn sobel_x() -> Self {
82        Self::from_hardcoded(
83            vec![-1.0, 0.0, 1.0, -2.0, 0.0, 2.0, -1.0, 0.0, 1.0],
84            3,
85            3,
86            false,
87        )
88    }
89
90    /// Create a Sobel Y kernel (3x3)
91    #[must_use]
92    pub fn sobel_y() -> Self {
93        Self::from_hardcoded(
94            vec![-1.0, -2.0, -1.0, 0.0, 0.0, 0.0, 1.0, 2.0, 1.0],
95            3,
96            3,
97            false,
98        )
99    }
100
101    /// Create a Laplacian kernel (3x3)
102    #[must_use]
103    pub fn laplacian() -> Self {
104        Self::from_hardcoded(
105            vec![0.0, 1.0, 0.0, 1.0, -4.0, 1.0, 0.0, 1.0, 0.0],
106            3,
107            3,
108            false,
109        )
110    }
111
112    /// Create a box blur kernel
113    pub fn box_blur(size: u32) -> Result<Self> {
114        let total = (size * size) as usize;
115        let value = 1.0 / total as f32;
116        let kernel = vec![value; total];
117        Self::new(kernel, size, size, false)
118    }
119
120    /// Create a sharpening kernel
121    #[must_use]
122    pub fn sharpen() -> Self {
123        Self::from_hardcoded(
124            vec![0.0, -1.0, 0.0, -1.0, 5.0, -1.0, 0.0, -1.0, 0.0],
125            3,
126            3,
127            false,
128        )
129    }
130
131    /// Get the kernel data
132    #[must_use]
133    pub fn data(&self) -> &[f32] {
134        &self.kernel
135    }
136
137    /// Get kernel dimensions
138    #[must_use]
139    pub fn dimensions(&self) -> (u32, u32) {
140        (self.width, self.height)
141    }
142
143    /// Check if normalization is enabled
144    #[must_use]
145    pub fn is_normalized(&self) -> bool {
146        self.normalize
147    }
148
149    /// Apply the convolution kernel
150    ///
151    /// # Arguments
152    ///
153    /// * `device` - GPU device
154    /// * `input` - Input image buffer
155    /// * `output` - Output image buffer
156    /// * `width` - Image width
157    /// * `height` - Image height
158    ///
159    /// # Errors
160    ///
161    /// Returns an error if the operation fails.
162    pub fn apply(
163        &self,
164        device: &GpuDevice,
165        input: &[u8],
166        output: &mut [u8],
167        width: u32,
168        height: u32,
169    ) -> Result<()> {
170        crate::ops::FilterOperation::convolve(
171            device,
172            input,
173            output,
174            width,
175            height,
176            &self.kernel,
177            self.normalize,
178        )
179    }
180}
181
182/// Image filter kernel
183pub struct FilterKernel {
184    filter_type: FilterType,
185    sigma: f32,
186    /// Range (colour) sigma for the bilateral filter. Unused by other filter types.
187    sigma_range: f32,
188    kernel_size: u32,
189}
190
191impl FilterKernel {
192    /// Create a new filter kernel
193    #[must_use]
194    pub fn new(filter_type: FilterType, sigma: f32, kernel_size: u32) -> Self {
195        Self {
196            filter_type,
197            sigma,
198            sigma_range: 50.0, // sensible default for bilateral; irrelevant for others
199            kernel_size,
200        }
201    }
202
203    /// Create a Gaussian blur filter
204    #[must_use]
205    pub fn gaussian_blur(sigma: f32) -> Self {
206        let kernel_size = Self::gaussian_kernel_size(sigma);
207        Self::new(FilterType::GaussianBlur, sigma, kernel_size)
208    }
209
210    /// Create a box blur filter
211    #[must_use]
212    pub fn box_blur(radius: u32) -> Self {
213        let kernel_size = radius * 2 + 1;
214        Self::new(FilterType::BoxBlur, 0.0, kernel_size)
215    }
216
217    /// Create an unsharp mask filter (sharpening)
218    #[must_use]
219    pub fn sharpen(amount: f32) -> Self {
220        Self::new(FilterType::UnsharpMask, amount, 5)
221    }
222
223    /// Create a Sobel edge detection filter
224    #[must_use]
225    pub fn sobel() -> Self {
226        Self::new(FilterType::Sobel, 0.0, 3)
227    }
228
229    /// Create a bilateral filter
230    #[must_use]
231    pub fn bilateral(sigma_spatial: f32, sigma_range: f32) -> Self {
232        let kernel_size = Self::gaussian_kernel_size(sigma_spatial);
233        Self {
234            filter_type: FilterType::Bilateral,
235            sigma: sigma_spatial,
236            sigma_range,
237            kernel_size,
238        }
239    }
240
241    /// Execute the filter operation
242    ///
243    /// # Arguments
244    ///
245    /// * `device` - GPU device
246    /// * `input` - Input image buffer
247    /// * `output` - Output image buffer
248    /// * `width` - Image width
249    /// * `height` - Image height
250    ///
251    /// # Errors
252    ///
253    /// Returns an error if the operation fails.
254    pub fn execute(
255        &self,
256        device: &GpuDevice,
257        input: &[u8],
258        output: &mut [u8],
259        width: u32,
260        height: u32,
261    ) -> Result<()> {
262        match self.filter_type {
263            FilterType::GaussianBlur => crate::ops::FilterOperation::gaussian_blur(
264                device, input, output, width, height, self.sigma,
265            ),
266            FilterType::UnsharpMask => crate::ops::FilterOperation::sharpen(
267                device, input, output, width, height, self.sigma,
268            ),
269            FilterType::Sobel | FilterType::Scharr | FilterType::Laplacian => {
270                crate::ops::FilterOperation::edge_detect(device, input, output, width, height)
271            }
272            FilterType::BoxBlur => {
273                // Derive radius from kernel_size (kernel_size = 2*radius+1).
274                let radius = self.kernel_size / 2;
275                let result = crate::ops::box_blur(input, width, height, 4, radius)?;
276                if result.len() != output.len() {
277                    return Err(crate::GpuError::InvalidBufferSize {
278                        expected: output.len(),
279                        actual: result.len(),
280                    });
281                }
282                output.copy_from_slice(&result);
283                Ok(())
284            }
285            FilterType::Median => {
286                let radius = self.kernel_size / 2;
287                let result = crate::ops::median_filter(input, width, height, 4, radius)?;
288                if result.len() != output.len() {
289                    return Err(crate::GpuError::InvalidBufferSize {
290                        expected: output.len(),
291                        actual: result.len(),
292                    });
293                }
294                output.copy_from_slice(&result);
295                Ok(())
296            }
297            FilterType::Bilateral => {
298                let result = crate::ops::bilateral_filter(
299                    input,
300                    width,
301                    height,
302                    4,
303                    self.sigma,
304                    self.sigma_range,
305                )?;
306                if result.len() != output.len() {
307                    return Err(crate::GpuError::InvalidBufferSize {
308                        expected: output.len(),
309                        actual: result.len(),
310                    });
311                }
312                output.copy_from_slice(&result);
313                Ok(())
314            }
315            FilterType::Custom => {
316                // FilterType::Custom carries no embedded kernel data; the kernel-bearing
317                // path uses ConvolutionKernel::apply() directly. Return a descriptive
318                // error so callers know to switch to that API.
319                Err(crate::GpuError::NotSupported(
320                    "FilterType::Custom requires a ConvolutionKernel — use ConvolutionKernel::apply() instead of FilterKernel::execute()".to_string(),
321                ))
322            }
323        }
324    }
325
326    /// Get the filter type
327    #[must_use]
328    pub fn filter_type(&self) -> FilterType {
329        self.filter_type
330    }
331
332    /// Get the sigma parameter
333    #[must_use]
334    pub fn sigma(&self) -> f32 {
335        self.sigma
336    }
337
338    /// Get the kernel size
339    #[must_use]
340    pub fn kernel_size(&self) -> u32 {
341        self.kernel_size
342    }
343
344    /// Calculate Gaussian kernel size from sigma
345    fn gaussian_kernel_size(sigma: f32) -> u32 {
346        let radius = (3.0 * sigma).ceil() as u32;
347        2 * radius + 1
348    }
349
350    /// Estimate FLOPS for filter operation
351    #[must_use]
352    pub fn estimate_flops(width: u32, height: u32, kernel_size: u32) -> u64 {
353        let pixels = u64::from(width) * u64::from(height);
354        let ops_per_pixel = u64::from(kernel_size) * u64::from(kernel_size) * 4; // 4 channels
355        pixels * ops_per_pixel * 2 // multiply + add
356    }
357}
358
/// Separable filter for optimized 2D filtering.
///
/// Holds a horizontal and a vertical 1D kernel, so a 2D filter can be applied as two
/// 1D passes.
#[derive(Debug, Clone, PartialEq)]
pub struct SeparableFilter {
    horizontal_kernel: Vec<f32>,
    vertical_kernel: Vec<f32>,
}

impl SeparableFilter {
    /// Create a new separable filter from explicit horizontal and vertical kernels.
    #[must_use]
    pub fn new(horizontal: Vec<f32>, vertical: Vec<f32>) -> Self {
        Self {
            horizontal_kernel: horizontal,
            vertical_kernel: vertical,
        }
    }

    /// Create a Gaussian separable filter that uses the same normalized 1D kernel in
    /// both directions.
    ///
    /// The kernel has `2 * (size / 2) + 1` taps, so an even `size` yields `size + 1` taps.
    /// A zero (or NaN) `sigma` yields the unit impulse (a single centre tap of `1.0`).
    #[must_use]
    pub fn gaussian(sigma: f32, size: u32) -> Self {
        let kernel = Self::gaussian_kernel_1d(sigma, size);
        Self::new(kernel.clone(), kernel)
    }

    /// Generate a normalized 1D Gaussian kernel centred on its middle tap.
    fn gaussian_kernel_1d(sigma: f32, size: u32) -> Vec<f32> {
        // i64 so that `i * i` below cannot overflow for any u32 `size`.
        let radius = i64::from(size / 2);
        let two_sigma_sq = 2.0 * sigma * sigma;

        // With sigma == 0 (or NaN) the centre tap would be 0/0 = NaN, and normalization would
        // then turn every weight into NaN. The limit of a Gaussian as sigma -> 0 is the unit
        // impulse, so return that instead.
        if !(two_sigma_sq > 0.0) {
            let half = (size / 2) as usize;
            let mut impulse = vec![0.0; 2 * half + 1];
            impulse[half] = 1.0;
            return impulse;
        }

        let mut kernel: Vec<f32> = (-radius..=radius)
            .map(|i| (-((i * i) as f32) / two_sigma_sq).exp())
            .collect();

        // Normalize so the weights sum to one.
        let sum: f32 = kernel.iter().sum();
        for value in &mut kernel {
            *value /= sum;
        }

        kernel
    }

    /// Get the horizontal kernel
    #[must_use]
    pub fn horizontal(&self) -> &[f32] {
        &self.horizontal_kernel
    }

    /// Get the vertical kernel
    #[must_use]
    pub fn vertical(&self) -> &[f32] {
        &self.vertical_kernel
    }
}
415
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_convolution_kernel_creation() {
        let kernel = ConvolutionKernel::sobel_x();
        assert_eq!(kernel.dimensions(), (3, 3));
        assert_eq!(kernel.data().len(), 9);

        let kernel = ConvolutionKernel::sobel_y();
        assert_eq!(kernel.dimensions(), (3, 3));
        assert_eq!(kernel.data().len(), 9);

        let kernel = ConvolutionKernel::laplacian();
        assert_eq!(kernel.dimensions(), (3, 3));
    }

    #[test]
    fn test_convolution_kernel_rejects_size_mismatch() {
        assert!(ConvolutionKernel::new(vec![0.0; 8], 3, 3, false).is_err());
        assert!(ConvolutionKernel::new(vec![0.0; 10], 3, 3, false).is_err());
    }

    #[test]
    fn test_convolution_kernel_rejects_even_dimensions() {
        assert!(ConvolutionKernel::new(vec![0.0; 4], 2, 2, false).is_err());
        assert!(ConvolutionKernel::new(vec![0.0; 6], 3, 2, true).is_err());
        assert!(ConvolutionKernel::new(Vec::new(), 0, 0, false).is_err());
        assert!(ConvolutionKernel::box_blur(2).is_err());
        assert!(ConvolutionKernel::box_blur(0).is_err());
    }

    #[test]
    fn test_convolution_kernel_accepts_rectangular() {
        let kernel = ConvolutionKernel::new(vec![1.0; 15], 5, 3, true)
            .expect("5x3 kernel with 15 weights is valid");
        assert_eq!(kernel.dimensions(), (5, 3));
        assert!(kernel.is_normalized());
    }

    #[test]
    fn test_preset_kernel_weight_sums() {
        // Edge detectors respond with zero on flat regions; sharpening preserves them.
        for kernel in [
            ConvolutionKernel::sobel_x(),
            ConvolutionKernel::sobel_y(),
            ConvolutionKernel::laplacian(),
        ] {
            let sum: f32 = kernel.data().iter().sum();
            assert!(sum.abs() < 1e-6);
        }
        let sum: f32 = ConvolutionKernel::sharpen().data().iter().sum();
        assert!((sum - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_filter_kernel_creation() {
        let filter = FilterKernel::gaussian_blur(2.0);
        assert_eq!(filter.filter_type(), FilterType::GaussianBlur);
        assert_eq!(filter.sigma(), 2.0);

        let filter = FilterKernel::sobel();
        assert_eq!(filter.filter_type(), FilterType::Sobel);
        assert_eq!(filter.kernel_size(), 3);
    }

    #[test]
    fn test_filter_kernel_sizes() {
        // Gaussian kernels span ±ceil(3σ).
        assert_eq!(FilterKernel::gaussian_blur(1.0).kernel_size(), 7);
        assert_eq!(FilterKernel::gaussian_blur(0.5).kernel_size(), 5);
        // Box kernels are 2 * radius + 1 wide.
        assert_eq!(FilterKernel::box_blur(2).kernel_size(), 5);
        assert_eq!(FilterKernel::box_blur(0).kernel_size(), 1);
    }

    #[test]
    fn test_bilateral_sigmas() {
        let filter = FilterKernel::bilateral(1.0, 25.0);
        assert_eq!(filter.filter_type(), FilterType::Bilateral);
        assert_eq!(filter.kernel_size(), 7);
        assert_eq!(filter.sigma_range, 25.0);

        let filter = FilterKernel::new(FilterType::Bilateral, 1.0, 7);
        assert_eq!(filter.sigma_range, 50.0);
    }

    #[test]
    fn test_estimate_flops() {
        // 100 pixels * 9 taps * 4 channels * 2 ops
        assert_eq!(FilterKernel::estimate_flops(10, 10, 3), 7200);
        assert_eq!(FilterKernel::estimate_flops(0, 10, 3), 0);
    }

    #[test]
    fn test_separable_filter() {
        let filter = SeparableFilter::gaussian(1.0, 5);
        assert_eq!(filter.horizontal().len(), 5);
        assert_eq!(filter.vertical().len(), 5);

        // Check normalization
        let sum: f32 = filter.horizontal().iter().sum();
        assert!((sum - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_box_blur_kernel() {
        let kernel =
            ConvolutionKernel::box_blur(3).expect("box blur kernel creation should succeed");
        assert_eq!(kernel.dimensions(), (3, 3));
        let expected_value = 1.0 / 9.0;
        for &value in kernel.data() {
            assert!((value - expected_value).abs() < 0.001);
        }
    }
}