Skip to main content

oximedia_gpu/ops/
denoise.rs

//! GPU-accelerated video denoising operations.
//!
//! This module provides denoise filters for video frames that are intended to
//! run as compute shaders. Three algorithms are available:
//!
//! - **Gaussian** – simple spatial blur (low latency, lower quality)
//! - **`NonLocalMeans`** – patch-based NLM denoising (higher quality, more compute)
//! - **`BilateralFilter`** – edge-preserving spatial filter (good quality/speed trade-off)
//!
//! The compute-shader path is not wired up yet: all operations currently run on
//! a multi-threaded CPU implementation (rayon), whichever device is passed in.
//! If `GpuDevice::new` fails, it returns an error that the caller may handle by
//! switching to a software path.
13
14use crate::{GpuDevice, GpuError, Result};
15use bytemuck::{Pod, Zeroable};
16use rayon::prelude::*;
17
18// ============================================================================
19// Denoise algorithm selector
20// ============================================================================
21
/// Selects which denoise filter to run, together with its tuning parameters.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum DenoiseAlgorithm {
    /// Plain Gaussian blur.
    ///
    /// The kernel radius is derived from `sigma` as `ceil(3 * sigma)`.
    Gaussian {
        /// Standard deviation of the Gaussian kernel, in pixels (e.g. 1.5).
        sigma: f32,
    },

    /// Non-Local Means: averages pixels whose surrounding patches look alike.
    NonLocalMeans {
        /// Filter strength *h*; larger values smooth more (e.g. 10.0).
        h: f32,
        /// Radius of the comparison patch; the patch is
        /// `(2 * patch_radius + 1)` pixels on each side (e.g. 3).
        patch_radius: u32,
        /// Radius of the search window; candidates are taken from up to
        /// `search_radius` pixels away in each direction (e.g. 10).
        search_radius: u32,
    },

    /// Bilateral filter: a spatial blur that is attenuated across strong
    /// colour differences, so edges are preserved.
    BilateralFilter {
        /// Standard deviation of the spatial Gaussian, in pixels.
        sigma_spatial: f32,
        /// Standard deviation of the range Gaussian, in pixel-value units.
        sigma_range: f32,
    },
}
61
62// ============================================================================
63// Parameter structs (GPU-uploadable)
64// ============================================================================
65
66#[repr(C)]
67#[derive(Clone, Copy, Pod, Zeroable)]
68struct GaussianDenoiseParams {
69    width: u32,
70    height: u32,
71    kernel_radius: u32,
72    _pad: u32,
73    sigma: f32,
74    inv_two_sigma_sq: f32,
75    _pad2: [f32; 2],
76}
77
78#[repr(C)]
79#[derive(Clone, Copy, Pod, Zeroable)]
80struct BilateralParams {
81    width: u32,
82    height: u32,
83    kernel_radius: u32,
84    _pad: u32,
85    sigma_spatial: f32,
86    sigma_range: f32,
87    inv_two_sigma_s_sq: f32,
88    inv_two_sigma_r_sq: f32,
89}
90
91#[repr(C)]
92#[derive(Clone, Copy, Pod, Zeroable)]
93struct NlmParams {
94    width: u32,
95    height: u32,
96    patch_radius: u32,
97    search_radius: u32,
98    h_sq: f32,
99    inv_patch_area: f32,
100    _pad: [f32; 2],
101}
102
103// ============================================================================
104// Public DenoiseOperation API
105// ============================================================================
106
/// Namespace type for the denoise operations of this module.
///
/// # Note on GPU vs CPU execution
///
/// The compute-shader path is not wired up yet. Every algorithm currently
/// runs on the CPU, parallelised across threads with rayon; the device handle
/// taken by the public entry points is accepted but not used.
pub struct DenoiseOperation;
116
117impl DenoiseOperation {
118    /// Denoise an RGBA image using the selected algorithm.
119    ///
120    /// Both `input` and `output` must be `width * height * 4` bytes
121    /// (packed RGBA, one byte per channel).
122    ///
123    /// # Errors
124    ///
125    /// Returns an error if buffer sizes are invalid.
126    pub fn denoise(
127        device: &GpuDevice,
128        input: &[u8],
129        output: &mut [u8],
130        width: u32,
131        height: u32,
132        algorithm: DenoiseAlgorithm,
133    ) -> Result<()> {
134        super::utils::validate_dimensions(width, height)?;
135        super::utils::validate_buffer_size(input, width, height, 4)?;
136        super::utils::validate_buffer_size(output, width, height, 4)?;
137
138        match algorithm {
139            DenoiseAlgorithm::Gaussian { sigma } => {
140                Self::denoise_gaussian_cpu(input, output, width, height, sigma)
141            }
142            DenoiseAlgorithm::BilateralFilter {
143                sigma_spatial,
144                sigma_range,
145            } => Self::denoise_bilateral_cpu(
146                input,
147                output,
148                width,
149                height,
150                sigma_spatial,
151                sigma_range,
152            ),
153            DenoiseAlgorithm::NonLocalMeans {
154                h,
155                patch_radius,
156                search_radius,
157            } => {
158                Self::denoise_nlm_cpu(input, output, width, height, h, patch_radius, search_radius)
159            }
160        }
161        // Suppress unused-variable warning on device until GPU path is wired up.
162        .map(|()| {
163            let _ = device;
164        })
165    }
166
167    // -----------------------------------------------------------------------
168    // CPU SIMD fallback implementations (used until real compute shaders
169    // are compiled and linked via the wgpu pipeline).
170    // -----------------------------------------------------------------------
171
172    /// Gaussian denoise via separable 1D convolution.
173    #[allow(clippy::cast_possible_truncation)]
174    fn denoise_gaussian_cpu(
175        input: &[u8],
176        output: &mut [u8],
177        width: u32,
178        height: u32,
179        sigma: f32,
180    ) -> Result<()> {
181        let w = width as usize;
182        let h = height as usize;
183        let radius = (3.0 * sigma).ceil() as usize;
184
185        // Build separable 1D Gaussian kernel.
186        let kernel_size = 2 * radius + 1;
187        let mut kernel = vec![0.0f32; kernel_size];
188        let two_sigma_sq = 2.0 * sigma * sigma;
189        let mut sum = 0.0f32;
190        for (i, k) in kernel.iter_mut().enumerate() {
191            let x = i as f32 - radius as f32;
192            *k = (-(x * x) / two_sigma_sq).exp();
193            sum += *k;
194        }
195        for k in &mut kernel {
196            *k /= sum;
197        }
198
199        // Horizontal pass into temp buffer.
200        let mut temp = vec![0u8; input.len()];
201        temp.par_chunks_exact_mut(w * 4)
202            .enumerate()
203            .for_each(|(y, row)| {
204                if y >= h {
205                    return;
206                }
207                for x in 0..w {
208                    for c in 0..4usize {
209                        let mut acc = 0.0f32;
210                        for (ki, &kv) in kernel.iter().enumerate() {
211                            let sx = (x as i64 + ki as i64 - radius as i64).clamp(0, w as i64 - 1)
212                                as usize;
213                            acc += kv * f32::from(input[(y * w + sx) * 4 + c]);
214                        }
215                        row[x * 4 + c] = acc.round().clamp(0.0, 255.0) as u8;
216                    }
217                }
218            });
219
220        // Vertical pass from temp to output.
221        output
222            .par_chunks_exact_mut(4)
223            .enumerate()
224            .for_each(|(i, pixel)| {
225                let x = i % w;
226                let y = i / w;
227                if y >= h {
228                    return;
229                }
230                for c in 0..4usize {
231                    let mut acc = 0.0f32;
232                    for (ki, &kv) in kernel.iter().enumerate() {
233                        let sy =
234                            (y as i64 + ki as i64 - radius as i64).clamp(0, h as i64 - 1) as usize;
235                        acc += kv * f32::from(temp[(sy * w + x) * 4 + c]);
236                    }
237                    pixel[c] = acc.round().clamp(0.0, 255.0) as u8;
238                }
239            });
240
241        Ok(())
242    }
243
244    /// Bilateral filter (edge-preserving denoising).
245    #[allow(clippy::cast_possible_truncation)]
246    fn denoise_bilateral_cpu(
247        input: &[u8],
248        output: &mut [u8],
249        width: u32,
250        height: u32,
251        sigma_spatial: f32,
252        sigma_range: f32,
253    ) -> Result<()> {
254        let w = width as usize;
255        let h = height as usize;
256        let radius = (3.0 * sigma_spatial).ceil() as usize;
257        let inv_two_ss_sq = 1.0 / (2.0 * sigma_spatial * sigma_spatial);
258        let inv_two_sr_sq = 1.0 / (2.0 * sigma_range * sigma_range);
259
260        output
261            .par_chunks_exact_mut(4)
262            .enumerate()
263            .for_each(|(i, pixel)| {
264                let x = i % w;
265                let y = i / w;
266                if y >= h {
267                    return;
268                }
269
270                let center = [
271                    f32::from(input[(y * w + x) * 4]),
272                    f32::from(input[(y * w + x) * 4 + 1]),
273                    f32::from(input[(y * w + x) * 4 + 2]),
274                    f32::from(input[(y * w + x) * 4 + 3]),
275                ];
276
277                let mut acc = [0.0f32; 4];
278                let mut weight_sum = 0.0f32;
279
280                for dy in -(radius as i64)..=(radius as i64) {
281                    for dx in -(radius as i64)..=(radius as i64) {
282                        let sx = (x as i64 + dx).clamp(0, w as i64 - 1) as usize;
283                        let sy = (y as i64 + dy).clamp(0, h as i64 - 1) as usize;
284
285                        let spatial_dist_sq = (dx * dx + dy * dy) as f32;
286                        let w_spatial = (-spatial_dist_sq * inv_two_ss_sq).exp();
287
288                        let neighbor = [
289                            f32::from(input[(sy * w + sx) * 4]),
290                            f32::from(input[(sy * w + sx) * 4 + 1]),
291                            f32::from(input[(sy * w + sx) * 4 + 2]),
292                            f32::from(input[(sy * w + sx) * 4 + 3]),
293                        ];
294
295                        let range_dist_sq = (0..3)
296                            .map(|c| (center[c] - neighbor[c]).powi(2))
297                            .sum::<f32>();
298                        let w_range = (-range_dist_sq * inv_two_sr_sq).exp();
299
300                        let w_total = w_spatial * w_range;
301                        weight_sum += w_total;
302
303                        for c in 0..4 {
304                            acc[c] += w_total * neighbor[c];
305                        }
306                    }
307                }
308
309                if weight_sum > 0.0 {
310                    for c in 0..4 {
311                        pixel[c] = (acc[c] / weight_sum).round().clamp(0.0, 255.0) as u8;
312                    }
313                } else {
314                    pixel.copy_from_slice(&input[i * 4..i * 4 + 4]);
315                }
316            });
317
318        Ok(())
319    }
320
321    /// Non-Local Means denoising (CPU path).
322    ///
323    /// This is a simplified NLM that compares patches in a search window.
324    #[allow(clippy::cast_possible_truncation)]
325    fn denoise_nlm_cpu(
326        input: &[u8],
327        output: &mut [u8],
328        width: u32,
329        height: u32,
330        h: f32,
331        patch_radius: u32,
332        search_radius: u32,
333    ) -> Result<()> {
334        if h <= 0.0 {
335            return Err(GpuError::Internal(
336                "NLM filter strength h must be positive".to_string(),
337            ));
338        }
339
340        let w = width as usize;
341        let ht = height as usize;
342        let pr = patch_radius as usize;
343        let sr = search_radius as usize;
344        let h_sq = h * h;
345        let patch_area = ((2 * pr + 1) * (2 * pr + 1)) as f32;
346        let inv_h_sq_patch = 1.0 / (h_sq * patch_area);
347
348        output
349            .par_chunks_exact_mut(4)
350            .enumerate()
351            .for_each(|(i, pixel)| {
352                let px = i % w;
353                let py = i / w;
354                if py >= ht {
355                    return;
356                }
357
358                let mut acc = [0.0f32; 4];
359                let mut weight_sum = 0.0f32;
360
361                // Iterate over search window.
362                for qy in
363                    (py as i64 - sr as i64).max(0)..=(py as i64 + sr as i64).min(ht as i64 - 1)
364                {
365                    for qx in
366                        (px as i64 - sr as i64).max(0)..=(px as i64 + sr as i64).min(w as i64 - 1)
367                    {
368                        // Compute patch distance between (px,py) and (qx,qy).
369                        let mut patch_dist_sq = 0.0f32;
370                        for ky in -(pr as i64)..=(pr as i64) {
371                            for kx in -(pr as i64)..=(pr as i64) {
372                                let p_x = (px as i64 + kx).clamp(0, w as i64 - 1) as usize;
373                                let p_y = (py as i64 + ky).clamp(0, ht as i64 - 1) as usize;
374                                let q_x = (qx + kx).clamp(0, w as i64 - 1) as usize;
375                                let q_y = (qy + ky).clamp(0, ht as i64 - 1) as usize;
376
377                                // Use luma (channel 0) for patch comparison.
378                                let diff = f32::from(input[(p_y * w + p_x) * 4])
379                                    - f32::from(input[(q_y * w + q_x) * 4]);
380                                patch_dist_sq += diff * diff;
381                            }
382                        }
383
384                        let w_nlm = (-patch_dist_sq * inv_h_sq_patch).exp();
385                        weight_sum += w_nlm;
386
387                        for c in 0..4 {
388                            acc[c] +=
389                                w_nlm * f32::from(input[(qy as usize * w + qx as usize) * 4 + c]);
390                        }
391                    }
392                }
393
394                if weight_sum > 0.0 {
395                    for c in 0..4 {
396                        pixel[c] = (acc[c] / weight_sum).round().clamp(0.0, 255.0) as u8;
397                    }
398                } else {
399                    pixel.copy_from_slice(&input[i * 4..i * 4 + 4]);
400                }
401            });
402
403        Ok(())
404    }
405
406    /// Validate that `sigma > 0.0` and return a descriptive error if not.
407    #[allow(dead_code)]
408    fn check_sigma(sigma: f32, name: &str) -> Result<()> {
409        if sigma <= 0.0 {
410            Err(GpuError::Internal(format!(
411                "{name} must be positive, got {sigma}"
412            )))
413        } else {
414            Ok(())
415        }
416    }
417
418    /// Convenience: Gaussian denoise with automatic sigma selection.
419    ///
420    /// `noise_level` is in the range \[0.0, 1.0\] where 0.0 = no noise
421    /// and 1.0 = heavy noise.
422    ///
423    /// # Errors
424    ///
425    /// Returns an error if buffer sizes are invalid.
426    pub fn auto_denoise(
427        device: &GpuDevice,
428        input: &[u8],
429        output: &mut [u8],
430        width: u32,
431        height: u32,
432        noise_level: f32,
433    ) -> Result<()> {
434        let sigma = noise_level.clamp(0.0, 1.0) * 3.0 + 0.5;
435        Self::denoise(
436            device,
437            input,
438            output,
439            width,
440            height,
441            DenoiseAlgorithm::Gaussian { sigma },
442        )
443    }
444}
445
446// ============================================================================
447// Denoise kernel wrappers (kernel module integration)
448// ============================================================================
449
450/// Denoise kernel configuration for use with the `kernels` module.
451#[derive(Debug, Clone)]
452pub struct DenoiseKernel {
453    algorithm: DenoiseAlgorithm,
454}
455
456impl DenoiseKernel {
457    /// Create a new denoise kernel with the given algorithm.
458    #[must_use]
459    pub fn new(algorithm: DenoiseAlgorithm) -> Self {
460        Self { algorithm }
461    }
462
463    /// Create a Gaussian denoise kernel.
464    #[must_use]
465    pub fn gaussian(sigma: f32) -> Self {
466        Self::new(DenoiseAlgorithm::Gaussian { sigma })
467    }
468
469    /// Create a bilateral filter denoise kernel.
470    #[must_use]
471    pub fn bilateral(sigma_spatial: f32, sigma_range: f32) -> Self {
472        Self::new(DenoiseAlgorithm::BilateralFilter {
473            sigma_spatial,
474            sigma_range,
475        })
476    }
477
478    /// Create an NLM denoise kernel.
479    #[must_use]
480    pub fn nlm(h: f32, patch_radius: u32, search_radius: u32) -> Self {
481        Self::new(DenoiseAlgorithm::NonLocalMeans {
482            h,
483            patch_radius,
484            search_radius,
485        })
486    }
487
488    /// Apply this kernel to an RGBA image.
489    ///
490    /// # Errors
491    ///
492    /// Returns an error if buffer sizes are invalid or the operation fails.
493    pub fn apply(
494        &self,
495        device: &GpuDevice,
496        input: &[u8],
497        output: &mut [u8],
498        width: u32,
499        height: u32,
500    ) -> Result<()> {
501        DenoiseOperation::denoise(device, input, output, width, height, self.algorithm)
502    }
503
504    /// Get the algorithm used by this kernel.
505    #[must_use]
506    pub fn algorithm(&self) -> DenoiseAlgorithm {
507        self.algorithm
508    }
509
510    /// Estimate GFLOP for `width × height` frame at this algorithm.
511    #[must_use]
512    pub fn estimate_gflops(&self, width: u32, height: u32) -> f64 {
513        let pixels = u64::from(width) * u64::from(height);
514        let ops: u64 = match self.algorithm {
515            DenoiseAlgorithm::Gaussian { sigma } => {
516                let r = (3.0 * sigma).ceil() as u64;
517                let k = 2 * r + 1;
518                pixels * k * 4 * 4 // 2 passes, ~4 ops/tap, 4 channels
519            }
520            DenoiseAlgorithm::BilateralFilter { sigma_spatial, .. } => {
521                let r = (3.0 * sigma_spatial).ceil() as u64;
522                let k = (2 * r + 1).pow(2);
523                pixels * k * 12 * 4 // exp + mul + add, 4 channels
524            }
525            DenoiseAlgorithm::NonLocalMeans {
526                patch_radius,
527                search_radius,
528                ..
529            } => {
530                let pr = u64::from(2 * patch_radius + 1).pow(2);
531                let sr = u64::from(2 * search_radius + 1).pow(2);
532                pixels * sr * pr * 5 // patch distance + weighting
533            }
534        };
535        ops as f64 / 1e9
536    }
537}
538
539// ============================================================================
540// Tests
541// ============================================================================
542
#[cfg(test)]
mod tests {
    use super::*;

    /// Byte length of a packed RGBA buffer for a `w × h` image.
    fn rgba_len(w: u32, h: u32) -> usize {
        (w * h * 4) as usize
    }

    /// Image in which every byte (all channels, alpha included) is `value`.
    fn gray_image(w: u32, h: u32, value: u8) -> Vec<u8> {
        vec![value; rgba_len(w, h)]
    }

    /// Deterministic pseudo-noise: byte `i` is `(i as u8) * 37`, wrapping.
    fn noisy_image(w: u32, h: u32) -> Vec<u8> {
        (0..(w * h * 4))
            .map(|i| (i as u8).wrapping_mul(37))
            .collect()
    }

    // ---- CPU filter paths --------------------------------------------------

    #[test]
    fn test_gaussian_denoise_cpu_constant_image() {
        let (w, h) = (16u32, 16u32);
        let input = gray_image(w, h, 200);
        let mut output = vec![0u8; rgba_len(w, h)];

        let result = DenoiseOperation::denoise_gaussian_cpu(&input, &mut output, w, h, 1.5);

        assert!(result.is_ok());
        // Blurring a constant image must leave every byte at 200.
        assert!(output.iter().all(|&v| v == 200));
    }

    #[test]
    fn test_gaussian_denoise_cpu_noisy() {
        let (w, h) = (32u32, 32u32);
        let input = noisy_image(w, h);
        let mut output = vec![0u8; rgba_len(w, h)];

        let result = DenoiseOperation::denoise_gaussian_cpu(&input, &mut output, w, h, 2.0);

        assert!(result.is_ok());
        // The filter must have written something other than zeros.
        assert!(output.iter().any(|&v| v > 0));
    }

    #[test]
    fn test_bilateral_denoise_cpu_constant() {
        let (w, h) = (8u32, 8u32);
        let input = gray_image(w, h, 100);
        let mut output = vec![0u8; rgba_len(w, h)];

        let result = DenoiseOperation::denoise_bilateral_cpu(&input, &mut output, w, h, 1.5, 30.0);

        assert!(result.is_ok());
        assert!(output.iter().all(|&v| v == 100));
    }

    #[test]
    fn test_nlm_denoise_cpu_constant() {
        let (w, h) = (8u32, 8u32);
        let input = gray_image(w, h, 150);
        let mut output = vec![0u8; rgba_len(w, h)];

        let result = DenoiseOperation::denoise_nlm_cpu(&input, &mut output, w, h, 10.0, 2, 5);

        assert!(result.is_ok());
        assert!(output.iter().all(|&v| v == 150));
    }

    #[test]
    fn test_nlm_denoise_invalid_h() {
        let (w, h) = (4u32, 4u32);
        let input = gray_image(w, h, 0);
        let mut output = vec![0u8; rgba_len(w, h)];

        // A zero filter strength is rejected.
        let result = DenoiseOperation::denoise_nlm_cpu(&input, &mut output, w, h, 0.0, 1, 3);

        assert!(result.is_err());
    }

    // ---- DenoiseKernel -----------------------------------------------------

    #[test]
    fn test_denoise_kernel_gaussian() {
        let expected = DenoiseAlgorithm::Gaussian { sigma: 1.0 };
        assert_eq!(DenoiseKernel::gaussian(1.0).algorithm(), expected);
    }

    #[test]
    fn test_denoise_kernel_bilateral() {
        let expected = DenoiseAlgorithm::BilateralFilter {
            sigma_spatial: 2.0,
            sigma_range: 25.0,
        };
        assert_eq!(DenoiseKernel::bilateral(2.0, 25.0).algorithm(), expected);
    }

    #[test]
    fn test_denoise_kernel_nlm() {
        let expected = DenoiseAlgorithm::NonLocalMeans {
            h: 10.0,
            patch_radius: 3,
            search_radius: 10,
        };
        assert_eq!(DenoiseKernel::nlm(10.0, 3, 10).algorithm(), expected);
    }

    #[test]
    fn test_estimate_gflops_not_zero() {
        let gaussian = DenoiseKernel::gaussian(1.5);
        assert!(gaussian.estimate_gflops(1920, 1080) > 0.0);

        let nlm = DenoiseKernel::nlm(10.0, 3, 10);
        assert!(nlm.estimate_gflops(1920, 1080) > 0.0);
    }
}