Skip to main content

oximedia_gpu/kernels/filter.rs

1//! Image filtering and convolution kernels
2
3use crate::{GpuDevice, Result};
4
/// Kind of image filter that a [`FilterKernel`] can be configured with.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FilterType {
    /// Gaussian blur.
    GaussianBlur,
    /// Box blur: an unweighted average over the window.
    BoxBlur,
    /// Median filter.
    Median,
    /// Bilateral filter (edge-preserving smoothing).
    Bilateral,
    /// Unsharp-mask sharpening.
    UnsharpMask,
    /// Sobel edge detector.
    Sobel,
    /// Scharr edge detector.
    Scharr,
    /// Laplacian edge detector.
    Laplacian,
    /// User-supplied convolution.
    Custom,
}
27
28/// Convolution kernel
29pub struct ConvolutionKernel {
30    kernel: Vec<f32>,
31    width: u32,
32    height: u32,
33    normalize: bool,
34}
35
36impl ConvolutionKernel {
37    /// Create a new convolution kernel
38    ///
39    /// # Arguments
40    ///
41    /// * `kernel` - Kernel weights (must be width * height in size)
42    /// * `width` - Kernel width (must be odd)
43    /// * `height` - Kernel height (must be odd)
44    /// * `normalize` - Whether to normalize the kernel
45    pub fn new(kernel: Vec<f32>, width: u32, height: u32, normalize: bool) -> Result<Self> {
46        if kernel.len() != (width * height) as usize {
47            return Err(crate::GpuError::Internal(
48                "Kernel size mismatch".to_string(),
49            ));
50        }
51
52        if width % 2 == 0 || height % 2 == 0 {
53            return Err(crate::GpuError::Internal(
54                "Kernel dimensions must be odd".to_string(),
55            ));
56        }
57
58        Ok(Self {
59            kernel,
60            width,
61            height,
62            normalize,
63        })
64    }
65
66    /// Infallible constructor for hardcoded kernels whose dimensions are
67    /// compile-time known to be valid (odd, matching length).
68    fn from_hardcoded(kernel: Vec<f32>, width: u32, height: u32, normalize: bool) -> Self {
69        debug_assert_eq!(kernel.len(), (width * height) as usize);
70        debug_assert!(width % 2 == 1 && height % 2 == 1);
71        Self {
72            kernel,
73            width,
74            height,
75            normalize,
76        }
77    }
78
79    /// Create a Sobel X kernel (3x3)
80    #[must_use]
81    pub fn sobel_x() -> Self {
82        Self::from_hardcoded(
83            vec![-1.0, 0.0, 1.0, -2.0, 0.0, 2.0, -1.0, 0.0, 1.0],
84            3,
85            3,
86            false,
87        )
88    }
89
90    /// Create a Sobel Y kernel (3x3)
91    #[must_use]
92    pub fn sobel_y() -> Self {
93        Self::from_hardcoded(
94            vec![-1.0, -2.0, -1.0, 0.0, 0.0, 0.0, 1.0, 2.0, 1.0],
95            3,
96            3,
97            false,
98        )
99    }
100
101    /// Create a Laplacian kernel (3x3)
102    #[must_use]
103    pub fn laplacian() -> Self {
104        Self::from_hardcoded(
105            vec![0.0, 1.0, 0.0, 1.0, -4.0, 1.0, 0.0, 1.0, 0.0],
106            3,
107            3,
108            false,
109        )
110    }
111
112    /// Create a box blur kernel
113    pub fn box_blur(size: u32) -> Result<Self> {
114        let total = (size * size) as usize;
115        let value = 1.0 / total as f32;
116        let kernel = vec![value; total];
117        Self::new(kernel, size, size, false)
118    }
119
120    /// Create a sharpening kernel
121    #[must_use]
122    pub fn sharpen() -> Self {
123        Self::from_hardcoded(
124            vec![0.0, -1.0, 0.0, -1.0, 5.0, -1.0, 0.0, -1.0, 0.0],
125            3,
126            3,
127            false,
128        )
129    }
130
131    /// Get the kernel data
132    #[must_use]
133    pub fn data(&self) -> &[f32] {
134        &self.kernel
135    }
136
137    /// Get kernel dimensions
138    #[must_use]
139    pub fn dimensions(&self) -> (u32, u32) {
140        (self.width, self.height)
141    }
142
143    /// Check if normalization is enabled
144    #[must_use]
145    pub fn is_normalized(&self) -> bool {
146        self.normalize
147    }
148
149    /// Apply the convolution kernel
150    ///
151    /// # Arguments
152    ///
153    /// * `device` - GPU device
154    /// * `input` - Input image buffer
155    /// * `output` - Output image buffer
156    /// * `width` - Image width
157    /// * `height` - Image height
158    ///
159    /// # Errors
160    ///
161    /// Returns an error if the operation fails.
162    pub fn apply(
163        &self,
164        device: &GpuDevice,
165        input: &[u8],
166        output: &mut [u8],
167        width: u32,
168        height: u32,
169    ) -> Result<()> {
170        crate::ops::FilterOperation::convolve(
171            device,
172            input,
173            output,
174            width,
175            height,
176            &self.kernel,
177            self.normalize,
178        )
179    }
180}
181
182/// Image filter kernel
183pub struct FilterKernel {
184    filter_type: FilterType,
185    sigma: f32,
186    kernel_size: u32,
187}
188
189impl FilterKernel {
190    /// Create a new filter kernel
191    #[must_use]
192    pub fn new(filter_type: FilterType, sigma: f32, kernel_size: u32) -> Self {
193        Self {
194            filter_type,
195            sigma,
196            kernel_size,
197        }
198    }
199
200    /// Create a Gaussian blur filter
201    #[must_use]
202    pub fn gaussian_blur(sigma: f32) -> Self {
203        let kernel_size = Self::gaussian_kernel_size(sigma);
204        Self::new(FilterType::GaussianBlur, sigma, kernel_size)
205    }
206
207    /// Create a box blur filter
208    #[must_use]
209    pub fn box_blur(radius: u32) -> Self {
210        let kernel_size = radius * 2 + 1;
211        Self::new(FilterType::BoxBlur, 0.0, kernel_size)
212    }
213
214    /// Create an unsharp mask filter (sharpening)
215    #[must_use]
216    pub fn sharpen(amount: f32) -> Self {
217        Self::new(FilterType::UnsharpMask, amount, 5)
218    }
219
220    /// Create a Sobel edge detection filter
221    #[must_use]
222    pub fn sobel() -> Self {
223        Self::new(FilterType::Sobel, 0.0, 3)
224    }
225
226    /// Create a bilateral filter
227    #[must_use]
228    pub fn bilateral(sigma_spatial: f32, _sigma_range: f32) -> Self {
229        let kernel_size = Self::gaussian_kernel_size(sigma_spatial);
230        Self::new(FilterType::Bilateral, sigma_spatial, kernel_size)
231    }
232
233    /// Execute the filter operation
234    ///
235    /// # Arguments
236    ///
237    /// * `device` - GPU device
238    /// * `input` - Input image buffer
239    /// * `output` - Output image buffer
240    /// * `width` - Image width
241    /// * `height` - Image height
242    ///
243    /// # Errors
244    ///
245    /// Returns an error if the operation fails.
246    pub fn execute(
247        &self,
248        device: &GpuDevice,
249        input: &[u8],
250        output: &mut [u8],
251        width: u32,
252        height: u32,
253    ) -> Result<()> {
254        match self.filter_type {
255            FilterType::GaussianBlur => crate::ops::FilterOperation::gaussian_blur(
256                device, input, output, width, height, self.sigma,
257            ),
258            FilterType::UnsharpMask => crate::ops::FilterOperation::sharpen(
259                device, input, output, width, height, self.sigma,
260            ),
261            FilterType::Sobel | FilterType::Scharr | FilterType::Laplacian => {
262                crate::ops::FilterOperation::edge_detect(device, input, output, width, height)
263            }
264            _ => Err(crate::GpuError::NotSupported(format!(
265                "Filter type {:?} not yet implemented",
266                self.filter_type
267            ))),
268        }
269    }
270
271    /// Get the filter type
272    #[must_use]
273    pub fn filter_type(&self) -> FilterType {
274        self.filter_type
275    }
276
277    /// Get the sigma parameter
278    #[must_use]
279    pub fn sigma(&self) -> f32 {
280        self.sigma
281    }
282
283    /// Get the kernel size
284    #[must_use]
285    pub fn kernel_size(&self) -> u32 {
286        self.kernel_size
287    }
288
289    /// Calculate Gaussian kernel size from sigma
290    fn gaussian_kernel_size(sigma: f32) -> u32 {
291        let radius = (3.0 * sigma).ceil() as u32;
292        2 * radius + 1
293    }
294
295    /// Estimate FLOPS for filter operation
296    #[must_use]
297    pub fn estimate_flops(width: u32, height: u32, kernel_size: u32) -> u64 {
298        let pixels = u64::from(width) * u64::from(height);
299        let ops_per_pixel = u64::from(kernel_size) * u64::from(kernel_size) * 4; // 4 channels
300        pixels * ops_per_pixel * 2 // multiply + add
301    }
302}
303
/// Separable filter for optimized 2D filtering, expressed as one horizontal
/// and one vertical 1D kernel.
#[derive(Debug, Clone, PartialEq)]
pub struct SeparableFilter {
    /// 1D taps for the horizontal direction.
    horizontal_kernel: Vec<f32>,
    /// 1D taps for the vertical direction.
    vertical_kernel: Vec<f32>,
}

impl SeparableFilter {
    /// Create a separable filter from explicit horizontal and vertical taps.
    #[must_use]
    pub fn new(horizontal: Vec<f32>, vertical: Vec<f32>) -> Self {
        Self {
            horizontal_kernel: horizontal,
            vertical_kernel: vertical,
        }
    }

    /// Create a separable Gaussian filter that uses the same normalized 1D
    /// kernel in both directions.
    ///
    /// The kernel has `2 * (size / 2) + 1` taps, so an even `size` is rounded
    /// up to the next odd length. A `sigma` of zero (or NaN) yields a unit
    /// impulse, i.e. an identity filter.
    #[must_use]
    pub fn gaussian(sigma: f32, size: u32) -> Self {
        let kernel = Self::gaussian_kernel_1d(sigma, size);
        Self::new(kernel.clone(), kernel)
    }

    /// Generate a normalized 1D Gaussian kernel with `2 * (size / 2) + 1` taps,
    /// centred on index `size / 2`.
    fn gaussian_kernel_1d(sigma: f32, size: u32) -> Vec<f32> {
        // u32::MAX / 2 == i32::MAX, so this conversion never truncates.
        let radius = (size / 2) as i32;
        let two_sigma_sq = 2.0 * sigma * sigma;

        // If sigma^2 is zero (sigma == 0, or so small it underflows) or NaN,
        // the centre tap would be 0/0 = NaN and poison every weight after
        // normalization. The sigma -> 0 limit of a Gaussian is a unit impulse.
        if two_sigma_sq.is_nan() || two_sigma_sq <= 0.0 {
            return (-radius..=radius)
                .map(|i| if i == 0 { 1.0 } else { 0.0 })
                .collect();
        }

        let mut kernel: Vec<f32> = (-radius..=radius)
            .map(|i| {
                // Square in f32: `i * i` in i32 overflows once radius > 46340.
                let x = i as f32;
                (-(x * x) / two_sigma_sq).exp()
            })
            .collect();

        // The centre tap is exp(0) = 1, so the sum is >= 1 and safe to divide by.
        let sum: f32 = kernel.iter().sum();
        for value in &mut kernel {
            *value /= sum;
        }

        kernel
    }

    /// Get the horizontal kernel.
    #[must_use]
    pub fn horizontal(&self) -> &[f32] {
        &self.horizontal_kernel
    }

    /// Get the vertical kernel.
    #[must_use]
    pub fn vertical(&self) -> &[f32] {
        &self.vertical_kernel
    }
}
360
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_convolution_kernel_creation() {
        let sobel = ConvolutionKernel::sobel_x();
        assert_eq!(sobel.dimensions(), (3, 3));
        assert_eq!(sobel.data().len(), 9);

        assert_eq!(ConvolutionKernel::laplacian().dimensions(), (3, 3));
    }

    #[test]
    fn test_filter_kernel_creation() {
        let gaussian = FilterKernel::gaussian_blur(2.0);
        assert_eq!(gaussian.filter_type(), FilterType::GaussianBlur);
        assert_eq!(gaussian.sigma(), 2.0);

        let edges = FilterKernel::sobel();
        assert_eq!(edges.filter_type(), FilterType::Sobel);
        assert_eq!(edges.kernel_size(), 3);
    }

    #[test]
    fn test_separable_filter() {
        let separable = SeparableFilter::gaussian(1.0, 5);
        let (row_taps, col_taps) = (separable.horizontal(), separable.vertical());
        assert_eq!(row_taps.len(), 5);
        assert_eq!(col_taps.len(), 5);

        // The normalized 1D taps should add up to one.
        let total: f32 = row_taps.iter().sum();
        assert!((total - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_box_blur_kernel() {
        let blur =
            ConvolutionKernel::box_blur(3).expect("box blur kernel creation should succeed");
        assert_eq!(blur.dimensions(), (3, 3));
        let expected = 1.0 / 9.0;
        assert!(blur
            .data()
            .iter()
            .all(|&weight| (weight - expected).abs() < 0.001));
    }
}