Skip to main content

oximedia_gpu/ops/
denoise.rs

1//! GPU-accelerated video denoising operations.
2//!
3//! This module provides compute-shader-based denoise filters for video frames.
4//! Three algorithms are available:
5//!
6//! - **Gaussian** – simple spatial blur (low latency, lower quality)
7//! - **`NonLocalMeans`** – patch-based NLM denoising (higher quality, more compute)
8//! - **`BilateralFilter`** – edge-preserving spatial filter (good quality/speed trade-off)
9//!
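//! Every algorithm has a multi-threaded (rayon) CPU implementation. Gaussian and
//! NLM always run on the CPU; the bilateral filter uses the CPU path on fallback
//! adapters or whenever its GPU dispatch fails. If no suitable GPU is available
//! at all, the `GpuDevice::new` failure path returns an error that the caller may
//! handle by switching to a software path.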
13//!
14//! ## GPU path (bilateral filter)
15//!
16//! When `device.is_fallback == false` the bilateral filter is executed on the GPU
17//! via a WGSL compute shader (`shaders/bilateral.wgsl`).  The NLM algorithm is
//! too expensive for a single-dispatch compute shader (O(search_area × patch_area)
//! per pixel, i.e. (2·search_radius+1)² × (2·patch_radius+1)² taps) and is
//! therefore kept as a CPU fallback; GPU NLM is planned for a
20//! future release using multi-pass tiled reduction.
21
22use crate::{
23    shader::{BindGroupLayoutBuilder, ShaderCompiler, ShaderSource},
24    GpuDevice, GpuError, Result,
25};
26use bytemuck::{Pod, Zeroable};
27use once_cell::sync::OnceCell;
28use rayon::prelude::*;
29use wgpu::{BindGroupLayout, ComputePipeline};
30
31// ============================================================================
32// Denoise algorithm selector
33// ============================================================================
34
/// Which denoising algorithm to run, together with its tuning parameters.
///
/// Passed by value to `DenoiseOperation::denoise` (it is `Copy`).
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum DenoiseAlgorithm {
    /// Plain spatial Gaussian blur.
    ///
    /// `sigma` is the kernel's standard deviation in pixels (1.5 is a typical
    /// choice); larger values blur more.
    Gaussian {
        /// Standard deviation of the Gaussian kernel, in pixels.
        sigma: f32,
    },

    /// Non-Local Means: each pixel becomes a weighted average of pixels in a
    /// search window, weighted by how similar their surrounding patches are.
    ///
    /// Typical values: `h = 10.0`, `patch_radius = 3`, `search_radius = 10`.
    NonLocalMeans {
        /// Filter strength *h*; a larger value smooths more aggressively.
        h: f32,
        /// Half-size of the square comparison patch.
        patch_radius: u32,
        /// Half-size of the square search window.
        search_radius: u32,
    },

    /// Edge-preserving bilateral filter.
    ///
    /// Neighbours are weighted both by spatial distance (`sigma_spatial`) and
    /// by difference in pixel value (`sigma_range`).
    BilateralFilter {
        /// Standard deviation of the spatial Gaussian, in pixels.
        sigma_spatial: f32,
        /// Standard deviation of the range (colour) Gaussian, in 8-bit value units.
        sigma_range: f32,
    },
}
74
75// ============================================================================
76// Parameter structs (GPU-uploadable)
77// ============================================================================
78
79#[repr(C)]
80#[derive(Clone, Copy, Pod, Zeroable)]
81struct GaussianDenoiseParams {
82    width: u32,
83    height: u32,
84    kernel_radius: u32,
85    _pad: u32,
86    sigma: f32,
87    inv_two_sigma_sq: f32,
88    _pad2: [f32; 2],
89}
90
/// GPU-side uniform layout (must match `BilateralParams` struct in bilateral.wgsl exactly).
///
/// The struct is uploaded byte-for-byte with `bytemuck::bytes_of`. Field order,
/// types and the explicit padding are therefore part of the shader interface
/// and must not be reordered. The total size is 32 bytes, which a unit test
/// asserts.
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable)]
struct BilateralParams {
    /// Image width in pixels.
    width: u32,
    /// Image height in pixels.
    height: u32,
    /// Half-width of the square filter window, computed as `ceil(3 * sigma_spatial)`.
    kernel_radius: u32,
    /// Explicit padding; always written as 0.
    _pad: u32,
    /// Spatial Gaussian standard deviation.
    sigma_spatial: f32,
    /// Range (colour) Gaussian standard deviation.
    sigma_range: f32,
    /// Precomputed `1 / (2 * sigma_spatial²)`.
    inv_two_sigma_s_sq: f32,
    /// Precomputed `1 / (2 * sigma_range²)`.
    inv_two_sigma_r_sq: f32,
}
104
105#[repr(C)]
106#[derive(Clone, Copy, Pod, Zeroable)]
107struct NlmParams {
108    width: u32,
109    height: u32,
110    patch_radius: u32,
111    search_radius: u32,
112    h_sq: f32,
113    inv_patch_area: f32,
114    _pad: [f32; 2],
115}
116
117// ============================================================================
118// Public DenoiseOperation API
119// ============================================================================
120
/// Entry point for the frame denoise filters (see [`DenoiseAlgorithm`]).
///
/// # Where each algorithm runs
///
/// - **Bilateral filter:** on a real adapter (`!device.is_fallback`) it is
///   dispatched as a wgpu compute shader. The entry point is
///   `bilateral_filter_main` in `shaders/bilateral.wgsl`. If anything on that
///   path fails, a warning is logged and the multi-threaded CPU implementation
///   is used instead. Fallback (software) adapters go straight to the CPU.
/// - **Gaussian and Non-Local Means:** always computed on the CPU. A GPU NLM
///   path is planned for a future release.
pub struct DenoiseOperation;
133
134impl DenoiseOperation {
135    /// Denoise an RGBA image using the selected algorithm.
136    ///
137    /// Both `input` and `output` must be `width * height * 4` bytes
138    /// (packed RGBA, one byte per channel).
139    ///
140    /// # Errors
141    ///
142    /// Returns an error if buffer sizes are invalid.
143    pub fn denoise(
144        device: &GpuDevice,
145        input: &[u8],
146        output: &mut [u8],
147        width: u32,
148        height: u32,
149        algorithm: DenoiseAlgorithm,
150    ) -> Result<()> {
151        super::utils::validate_dimensions(width, height)?;
152        super::utils::validate_buffer_size(input, width, height, 4)?;
153        super::utils::validate_buffer_size(output, width, height, 4)?;
154
155        match algorithm {
156            DenoiseAlgorithm::Gaussian { sigma } => {
157                Self::denoise_gaussian_cpu(input, output, width, height, sigma)
158            }
159            DenoiseAlgorithm::BilateralFilter {
160                sigma_spatial,
161                sigma_range,
162            } => {
163                // Prefer GPU path when the device is a real (non-fallback) adapter.
164                if !device.is_fallback {
165                    match Self::denoise_bilateral_gpu(
166                        device,
167                        input,
168                        output,
169                        width,
170                        height,
171                        sigma_spatial,
172                        sigma_range,
173                    ) {
174                        Ok(()) => return Ok(()),
175                        Err(e) => {
176                            tracing::warn!(
177                                "GPU bilateral filter failed ({e}), falling back to CPU"
178                            );
179                        }
180                    }
181                }
182                // CPU fallback (also used for software adapters).
183                Self::denoise_bilateral_cpu(
184                    input,
185                    output,
186                    width,
187                    height,
188                    sigma_spatial,
189                    sigma_range,
190                )
191            }
192            // NLM GPU path is future work — CPU fallback only.
193            DenoiseAlgorithm::NonLocalMeans {
194                h,
195                patch_radius,
196                search_radius,
197            } => {
198                Self::denoise_nlm_cpu(input, output, width, height, h, patch_radius, search_radius)
199            }
200        }
201    }
202
203    // -----------------------------------------------------------------------
204    // GPU path — bilateral filter
205    // -----------------------------------------------------------------------
206
207    /// Run the bilateral filter on the GPU via wgpu compute shader.
208    #[allow(clippy::cast_possible_truncation)]
209    fn denoise_bilateral_gpu(
210        device: &GpuDevice,
211        input: &[u8],
212        output: &mut [u8],
213        width: u32,
214        height: u32,
215        sigma_spatial: f32,
216        sigma_range: f32,
217    ) -> Result<()> {
218        use super::utils::{
219            calculate_dispatch_size, create_readback_buffer, create_storage_buffer,
220            create_uniform_buffer,
221        };
222
223        let pipeline = Self::get_bilateral_pipeline(device)?;
224        let layout = Self::get_bilateral_bind_group_layout(device)?;
225
226        // Build uniform params buffer.
227        let kernel_radius = (3.0 * sigma_spatial).ceil() as u32;
228        let inv_two_ss_sq = 1.0 / (2.0 * sigma_spatial * sigma_spatial);
229        let inv_two_sr_sq = 1.0 / (2.0 * sigma_range * sigma_range);
230
231        let params = BilateralParams {
232            width,
233            height,
234            kernel_radius,
235            _pad: 0,
236            sigma_spatial,
237            sigma_range,
238            inv_two_sigma_s_sq: inv_two_ss_sq,
239            inv_two_sigma_r_sq: inv_two_sr_sq,
240        };
241
242        // Pack the u8 RGBA input as u32 words (shader reads array<u32>).
243        // Each u32 packs one RGBA pixel: R<<24 | G<<16 | B<<8 | A.
244        let num_pixels = (width * height) as usize;
245        let mut input_u32: Vec<u32> = Vec::with_capacity(num_pixels);
246        for chunk in input.chunks_exact(4) {
247            let packed = ((chunk[0] as u32) << 24)
248                | ((chunk[1] as u32) << 16)
249                | ((chunk[2] as u32) << 8)
250                | (chunk[3] as u32);
251            input_u32.push(packed);
252        }
253
254        let input_bytes = bytemuck::cast_slice(&input_u32);
255        let output_len = num_pixels * 4; // u32 per pixel, 4 bytes each
256
257        let input_buffer = create_storage_buffer(device, input_bytes.len() as u64)?;
258        let output_buffer = create_storage_buffer(device, output_len as u64)?;
259        let params_buffer = create_uniform_buffer(device, bytemuck::bytes_of(&params))?;
260
261        device
262            .queue()
263            .write_buffer(input_buffer.buffer(), 0, input_bytes);
264
265        // Build bind group.
266        let compiler = ShaderCompiler::new(device);
267        let bind_group = compiler.create_bind_group(
268            "Bilateral Bind Group",
269            layout,
270            &[
271                wgpu::BindGroupEntry {
272                    binding: 0,
273                    resource: input_buffer.buffer().as_entire_binding(),
274                },
275                wgpu::BindGroupEntry {
276                    binding: 1,
277                    resource: output_buffer.buffer().as_entire_binding(),
278                },
279                wgpu::BindGroupEntry {
280                    binding: 2,
281                    resource: params_buffer.buffer().as_entire_binding(),
282                },
283            ],
284        );
285
286        // Dispatch compute.
287        {
288            let mut encoder =
289                device
290                    .device()
291                    .create_command_encoder(&wgpu::CommandEncoderDescriptor {
292                        label: Some("Bilateral Compute Encoder"),
293                    });
294            {
295                let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
296                    label: Some("Bilateral Compute Pass"),
297                    timestamp_writes: None,
298                });
299                pass.set_pipeline(pipeline);
300                pass.set_bind_group(0, &bind_group, &[]);
301                let (dx, dy) = calculate_dispatch_size(width, height, (16, 16));
302                pass.dispatch_workgroups(dx, dy, 1);
303            }
304            device.queue().submit(Some(encoder.finish()));
305        }
306
307        // Copy output_buffer → readback buffer.
308        let readback = create_readback_buffer(device, output_len as u64)?;
309        {
310            let mut encoder =
311                device
312                    .device()
313                    .create_command_encoder(&wgpu::CommandEncoderDescriptor {
314                        label: Some("Bilateral Readback Encoder"),
315                    });
316            output_buffer.copy_to(&mut encoder, &readback, 0, 0, output_len as u64)?;
317            device.queue().submit(Some(encoder.finish()));
318        }
319
320        device.wait();
321
322        // Read back and unpack u32 → RGBA bytes.
323        let raw = readback.read(device, 0, output_len as u64)?;
324        let result_u32: &[u32] = bytemuck::cast_slice(&raw);
325        for (i, &packed) in result_u32.iter().enumerate() {
326            output[i * 4] = ((packed >> 24) & 0xFF) as u8;
327            output[i * 4 + 1] = ((packed >> 16) & 0xFF) as u8;
328            output[i * 4 + 2] = ((packed >> 8) & 0xFF) as u8;
329            output[i * 4 + 3] = (packed & 0xFF) as u8;
330        }
331
332        Ok(())
333    }
334
335    // -----------------------------------------------------------------------
336    // Pipeline management (cached per process lifetime via OnceCell)
337    // -----------------------------------------------------------------------
338
339    fn get_bilateral_bind_group_layout(device: &GpuDevice) -> Result<&'static BindGroupLayout> {
340        static LAYOUT: OnceCell<BindGroupLayout> = OnceCell::new();
341        Ok(LAYOUT.get_or_init(|| {
342            let compiler = ShaderCompiler::new(device);
343            let entries = BindGroupLayoutBuilder::new()
344                .add_storage_buffer_read_only(0) // input
345                .add_storage_buffer(1) // output
346                .add_uniform_buffer(2) // params
347                .build();
348            compiler.create_bind_group_layout("Bilateral Bind Group Layout", &entries)
349        }))
350    }
351
352    fn get_bilateral_pipeline(device: &GpuDevice) -> Result<&'static ComputePipeline> {
353        static PIPELINE: OnceCell<std::result::Result<ComputePipeline, String>> = OnceCell::new();
354        PIPELINE
355            .get_or_init(|| Self::init_bilateral_pipeline(device))
356            .as_ref()
357            .map_err(|e| GpuError::PipelineCreation(e.clone()))
358    }
359
360    fn init_bilateral_pipeline(device: &GpuDevice) -> std::result::Result<ComputePipeline, String> {
361        let compiler = ShaderCompiler::new(device);
362        let shader = compiler
363            .compile(
364                "Bilateral Shader",
365                ShaderSource::Embedded(crate::shader::embedded::BILATERAL_SHADER),
366            )
367            .map_err(|e| format!("Failed to compile bilateral shader: {e}"))?;
368
369        let layout = Self::get_bilateral_bind_group_layout(device)
370            .map_err(|e| format!("Failed to create bilateral bind group layout: {e}"))?;
371
372        compiler
373            .create_pipeline(
374                "Bilateral Pipeline",
375                &shader,
376                "bilateral_filter_main",
377                layout,
378            )
379            .map_err(|e| format!("Failed to create bilateral pipeline: {e}"))
380    }
381
382    // -----------------------------------------------------------------------
383    // CPU SIMD fallback implementations
384    // -----------------------------------------------------------------------
385
386    /// Gaussian denoise via separable 1D convolution.
387    #[allow(clippy::cast_possible_truncation)]
388    fn denoise_gaussian_cpu(
389        input: &[u8],
390        output: &mut [u8],
391        width: u32,
392        height: u32,
393        sigma: f32,
394    ) -> Result<()> {
395        let w = width as usize;
396        let h = height as usize;
397        let radius = (3.0 * sigma).ceil() as usize;
398
399        // Build separable 1D Gaussian kernel.
400        let kernel_size = 2 * radius + 1;
401        let mut kernel = vec![0.0f32; kernel_size];
402        let two_sigma_sq = 2.0 * sigma * sigma;
403        let mut sum = 0.0f32;
404        for (i, k) in kernel.iter_mut().enumerate() {
405            let x = i as f32 - radius as f32;
406            *k = (-(x * x) / two_sigma_sq).exp();
407            sum += *k;
408        }
409        for k in &mut kernel {
410            *k /= sum;
411        }
412
413        // Horizontal pass into temp buffer.
414        let mut temp = vec![0u8; input.len()];
415        temp.par_chunks_exact_mut(w * 4)
416            .enumerate()
417            .for_each(|(y, row)| {
418                if y >= h {
419                    return;
420                }
421                for x in 0..w {
422                    for c in 0..4usize {
423                        let mut acc = 0.0f32;
424                        for (ki, &kv) in kernel.iter().enumerate() {
425                            let sx = (x as i64 + ki as i64 - radius as i64).clamp(0, w as i64 - 1)
426                                as usize;
427                            acc += kv * f32::from(input[(y * w + sx) * 4 + c]);
428                        }
429                        row[x * 4 + c] = acc.round().clamp(0.0, 255.0) as u8;
430                    }
431                }
432            });
433
434        // Vertical pass from temp to output.
435        output
436            .par_chunks_exact_mut(4)
437            .enumerate()
438            .for_each(|(i, pixel)| {
439                let x = i % w;
440                let y = i / w;
441                if y >= h {
442                    return;
443                }
444                for c in 0..4usize {
445                    let mut acc = 0.0f32;
446                    for (ki, &kv) in kernel.iter().enumerate() {
447                        let sy =
448                            (y as i64 + ki as i64 - radius as i64).clamp(0, h as i64 - 1) as usize;
449                        acc += kv * f32::from(temp[(sy * w + x) * 4 + c]);
450                    }
451                    pixel[c] = acc.round().clamp(0.0, 255.0) as u8;
452                }
453            });
454
455        Ok(())
456    }
457
458    /// Bilateral filter — CPU fallback (edge-preserving denoising).
459    #[allow(clippy::cast_possible_truncation)]
460    fn denoise_bilateral_cpu(
461        input: &[u8],
462        output: &mut [u8],
463        width: u32,
464        height: u32,
465        sigma_spatial: f32,
466        sigma_range: f32,
467    ) -> Result<()> {
468        let w = width as usize;
469        let h = height as usize;
470        let radius = (3.0 * sigma_spatial).ceil() as usize;
471        let inv_two_ss_sq = 1.0 / (2.0 * sigma_spatial * sigma_spatial);
472        let inv_two_sr_sq = 1.0 / (2.0 * sigma_range * sigma_range);
473
474        output
475            .par_chunks_exact_mut(4)
476            .enumerate()
477            .for_each(|(i, pixel)| {
478                let x = i % w;
479                let y = i / w;
480                if y >= h {
481                    return;
482                }
483
484                let center = [
485                    f32::from(input[(y * w + x) * 4]),
486                    f32::from(input[(y * w + x) * 4 + 1]),
487                    f32::from(input[(y * w + x) * 4 + 2]),
488                    f32::from(input[(y * w + x) * 4 + 3]),
489                ];
490
491                let mut acc = [0.0f32; 4];
492                let mut weight_sum = 0.0f32;
493
494                for dy in -(radius as i64)..=(radius as i64) {
495                    for dx in -(radius as i64)..=(radius as i64) {
496                        let sx = (x as i64 + dx).clamp(0, w as i64 - 1) as usize;
497                        let sy = (y as i64 + dy).clamp(0, h as i64 - 1) as usize;
498
499                        let spatial_dist_sq = (dx * dx + dy * dy) as f32;
500                        let w_spatial = (-spatial_dist_sq * inv_two_ss_sq).exp();
501
502                        let neighbor = [
503                            f32::from(input[(sy * w + sx) * 4]),
504                            f32::from(input[(sy * w + sx) * 4 + 1]),
505                            f32::from(input[(sy * w + sx) * 4 + 2]),
506                            f32::from(input[(sy * w + sx) * 4 + 3]),
507                        ];
508
509                        let range_dist_sq = (0..3)
510                            .map(|c| (center[c] - neighbor[c]).powi(2))
511                            .sum::<f32>();
512                        let w_range = (-range_dist_sq * inv_two_sr_sq).exp();
513
514                        let w_total = w_spatial * w_range;
515                        weight_sum += w_total;
516
517                        for c in 0..4 {
518                            acc[c] += w_total * neighbor[c];
519                        }
520                    }
521                }
522
523                if weight_sum > 0.0 {
524                    for c in 0..4 {
525                        pixel[c] = (acc[c] / weight_sum).round().clamp(0.0, 255.0) as u8;
526                    }
527                } else {
528                    pixel.copy_from_slice(&input[i * 4..i * 4 + 4]);
529                }
530            });
531
532        Ok(())
533    }
534
535    /// Non-Local Means denoising (CPU path).
536    ///
537    /// GPU NLM is future work — the O(search_area² × patch_area) per-pixel cost
538    /// requires a multi-pass tiled reduction that does not map naturally to a
539    /// single compute dispatch.
540    #[allow(clippy::cast_possible_truncation)]
541    fn denoise_nlm_cpu(
542        input: &[u8],
543        output: &mut [u8],
544        width: u32,
545        height: u32,
546        h: f32,
547        patch_radius: u32,
548        search_radius: u32,
549    ) -> Result<()> {
550        if h <= 0.0 {
551            return Err(GpuError::Internal(
552                "NLM filter strength h must be positive".to_string(),
553            ));
554        }
555
556        let w = width as usize;
557        let ht = height as usize;
558        let pr = patch_radius as usize;
559        let sr = search_radius as usize;
560        let h_sq = h * h;
561        let patch_area = ((2 * pr + 1) * (2 * pr + 1)) as f32;
562        let inv_h_sq_patch = 1.0 / (h_sq * patch_area);
563
564        output
565            .par_chunks_exact_mut(4)
566            .enumerate()
567            .for_each(|(i, pixel)| {
568                let px = i % w;
569                let py = i / w;
570                if py >= ht {
571                    return;
572                }
573
574                let mut acc = [0.0f32; 4];
575                let mut weight_sum = 0.0f32;
576
577                // Iterate over search window.
578                for qy in
579                    (py as i64 - sr as i64).max(0)..=(py as i64 + sr as i64).min(ht as i64 - 1)
580                {
581                    for qx in
582                        (px as i64 - sr as i64).max(0)..=(px as i64 + sr as i64).min(w as i64 - 1)
583                    {
584                        // Compute patch distance between (px,py) and (qx,qy).
585                        let mut patch_dist_sq = 0.0f32;
586                        for ky in -(pr as i64)..=(pr as i64) {
587                            for kx in -(pr as i64)..=(pr as i64) {
588                                let p_x = (px as i64 + kx).clamp(0, w as i64 - 1) as usize;
589                                let p_y = (py as i64 + ky).clamp(0, ht as i64 - 1) as usize;
590                                let q_x = (qx + kx).clamp(0, w as i64 - 1) as usize;
591                                let q_y = (qy + ky).clamp(0, ht as i64 - 1) as usize;
592
593                                // Use luma (channel 0) for patch comparison.
594                                let diff = f32::from(input[(p_y * w + p_x) * 4])
595                                    - f32::from(input[(q_y * w + q_x) * 4]);
596                                patch_dist_sq += diff * diff;
597                            }
598                        }
599
600                        let w_nlm = (-patch_dist_sq * inv_h_sq_patch).exp();
601                        weight_sum += w_nlm;
602
603                        for c in 0..4 {
604                            acc[c] +=
605                                w_nlm * f32::from(input[(qy as usize * w + qx as usize) * 4 + c]);
606                        }
607                    }
608                }
609
610                if weight_sum > 0.0 {
611                    for c in 0..4 {
612                        pixel[c] = (acc[c] / weight_sum).round().clamp(0.0, 255.0) as u8;
613                    }
614                } else {
615                    pixel.copy_from_slice(&input[i * 4..i * 4 + 4]);
616                }
617            });
618
619        Ok(())
620    }
621
622    /// Validate that `sigma > 0.0` and return a descriptive error if not.
623    #[allow(dead_code)]
624    fn check_sigma(sigma: f32, name: &str) -> Result<()> {
625        if sigma <= 0.0 {
626            Err(GpuError::Internal(format!(
627                "{name} must be positive, got {sigma}"
628            )))
629        } else {
630            Ok(())
631        }
632    }
633
634    /// Convenience: Gaussian denoise with automatic sigma selection.
635    ///
636    /// `noise_level` is in the range \[0.0, 1.0\] where 0.0 = no noise
637    /// and 1.0 = heavy noise.
638    ///
639    /// # Errors
640    ///
641    /// Returns an error if buffer sizes are invalid.
642    pub fn auto_denoise(
643        device: &GpuDevice,
644        input: &[u8],
645        output: &mut [u8],
646        width: u32,
647        height: u32,
648        noise_level: f32,
649    ) -> Result<()> {
650        let sigma = noise_level.clamp(0.0, 1.0) * 3.0 + 0.5;
651        Self::denoise(
652            device,
653            input,
654            output,
655            width,
656            height,
657            DenoiseAlgorithm::Gaussian { sigma },
658        )
659    }
660}
661
662// ============================================================================
663// Denoise kernel wrappers (kernel module integration)
664// ============================================================================
665
666/// Denoise kernel configuration for use with the `kernels` module.
667#[derive(Debug, Clone)]
668pub struct DenoiseKernel {
669    algorithm: DenoiseAlgorithm,
670}
671
672impl DenoiseKernel {
673    /// Create a new denoise kernel with the given algorithm.
674    #[must_use]
675    pub fn new(algorithm: DenoiseAlgorithm) -> Self {
676        Self { algorithm }
677    }
678
679    /// Create a Gaussian denoise kernel.
680    #[must_use]
681    pub fn gaussian(sigma: f32) -> Self {
682        Self::new(DenoiseAlgorithm::Gaussian { sigma })
683    }
684
685    /// Create a bilateral filter denoise kernel.
686    #[must_use]
687    pub fn bilateral(sigma_spatial: f32, sigma_range: f32) -> Self {
688        Self::new(DenoiseAlgorithm::BilateralFilter {
689            sigma_spatial,
690            sigma_range,
691        })
692    }
693
694    /// Create an NLM denoise kernel.
695    #[must_use]
696    pub fn nlm(h: f32, patch_radius: u32, search_radius: u32) -> Self {
697        Self::new(DenoiseAlgorithm::NonLocalMeans {
698            h,
699            patch_radius,
700            search_radius,
701        })
702    }
703
704    /// Apply this kernel to an RGBA image.
705    ///
706    /// # Errors
707    ///
708    /// Returns an error if buffer sizes are invalid or the operation fails.
709    pub fn apply(
710        &self,
711        device: &GpuDevice,
712        input: &[u8],
713        output: &mut [u8],
714        width: u32,
715        height: u32,
716    ) -> Result<()> {
717        DenoiseOperation::denoise(device, input, output, width, height, self.algorithm)
718    }
719
720    /// Get the algorithm used by this kernel.
721    #[must_use]
722    pub fn algorithm(&self) -> DenoiseAlgorithm {
723        self.algorithm
724    }
725
726    /// Estimate GFLOP for `width × height` frame at this algorithm.
727    #[must_use]
728    pub fn estimate_gflops(&self, width: u32, height: u32) -> f64 {
729        let pixels = u64::from(width) * u64::from(height);
730        let ops: u64 = match self.algorithm {
731            DenoiseAlgorithm::Gaussian { sigma } => {
732                let r = (3.0 * sigma).ceil() as u64;
733                let k = 2 * r + 1;
734                pixels * k * 4 * 4 // 2 passes, ~4 ops/tap, 4 channels
735            }
736            DenoiseAlgorithm::BilateralFilter { sigma_spatial, .. } => {
737                let r = (3.0 * sigma_spatial).ceil() as u64;
738                let k = (2 * r + 1).pow(2);
739                pixels * k * 12 * 4 // exp + mul + add, 4 channels
740            }
741            DenoiseAlgorithm::NonLocalMeans {
742                patch_radius,
743                search_radius,
744                ..
745            } => {
746                let pr = u64::from(2 * patch_radius + 1).pow(2);
747                let sr = u64::from(2 * search_radius + 1).pow(2);
748                pixels * sr * pr * 5 // patch distance + weighting
749            }
750        };
751        ops as f64 / 1e9
752    }
753}
754
755// ============================================================================
756// Tests
757// ============================================================================
758
#[cfg(test)]
mod tests {
    use super::*;

    /// Uniform RGBA image: every byte of the `w × h × 4` buffer equals `value`.
    fn gray_image(w: u32, h: u32, value: u8) -> Vec<u8> {
        let byte_len = (w * h * 4) as usize;
        vec![value; byte_len]
    }

    /// Deterministic pseudo-noise: byte `i` is `(i as u8) * 37`, wrapping.
    fn noisy_image(w: u32, h: u32) -> Vec<u8> {
        let byte_count = w * h * 4;
        let mut bytes = Vec::with_capacity(byte_count as usize);
        bytes.extend((0..byte_count).map(|idx| (idx as u8).wrapping_mul(37)));
        bytes
    }

    /// Zero-filled output buffer for a `w × h` RGBA frame.
    fn zeroed_rgba(w: u32, h: u32) -> Vec<u8> {
        vec![0u8; (w * h * 4) as usize]
    }

    // ---- CPU filters -------------------------------------------------------

    #[test]
    fn test_gaussian_denoise_cpu_constant_image() {
        let (w, h) = (16u32, 16u32);
        let input = gray_image(w, h, 200);
        let mut output = zeroed_rgba(w, h);

        let outcome = DenoiseOperation::denoise_gaussian_cpu(&input, &mut output, w, h, 1.5);

        assert!(outcome.is_ok());
        // Blurring a flat field must leave every byte at 200.
        assert!(output.iter().all(|&v| v == 200));
    }

    #[test]
    fn test_gaussian_denoise_cpu_noisy() {
        let (w, h) = (32u32, 32u32);
        let input = noisy_image(w, h);
        let mut output = zeroed_rgba(w, h);

        let outcome = DenoiseOperation::denoise_gaussian_cpu(&input, &mut output, w, h, 2.0);

        assert!(outcome.is_ok());
        // The filter must actually write something.
        assert!(output.iter().any(|&v| v > 0));
    }

    #[test]
    fn test_bilateral_denoise_cpu_constant() {
        let (w, h) = (8u32, 8u32);
        let input = gray_image(w, h, 100);
        let mut output = zeroed_rgba(w, h);

        let outcome =
            DenoiseOperation::denoise_bilateral_cpu(&input, &mut output, w, h, 1.5, 30.0);

        assert!(outcome.is_ok());
        assert!(output.iter().all(|&v| v == 100));
    }

    #[test]
    fn test_nlm_denoise_cpu_constant() {
        let (w, h) = (8u32, 8u32);
        let input = gray_image(w, h, 150);
        let mut output = zeroed_rgba(w, h);

        let outcome = DenoiseOperation::denoise_nlm_cpu(&input, &mut output, w, h, 10.0, 2, 5);

        assert!(outcome.is_ok());
        assert!(output.iter().all(|&v| v == 150));
    }

    #[test]
    fn test_nlm_denoise_invalid_h() {
        let (w, h) = (4u32, 4u32);
        let input = gray_image(w, h, 0);
        let mut output = zeroed_rgba(w, h);

        let outcome = DenoiseOperation::denoise_nlm_cpu(&input, &mut output, w, h, 0.0, 1, 3);

        assert!(outcome.is_err());
    }

    // ---- DenoiseKernel -----------------------------------------------------

    #[test]
    fn test_denoise_kernel_gaussian() {
        let kernel = DenoiseKernel::gaussian(1.0);
        let expected = DenoiseAlgorithm::Gaussian { sigma: 1.0 };
        assert_eq!(kernel.algorithm(), expected);
    }

    #[test]
    fn test_denoise_kernel_bilateral() {
        let kernel = DenoiseKernel::bilateral(2.0, 25.0);
        let expected = DenoiseAlgorithm::BilateralFilter {
            sigma_spatial: 2.0,
            sigma_range: 25.0,
        };
        assert_eq!(kernel.algorithm(), expected);
    }

    #[test]
    fn test_denoise_kernel_nlm() {
        let kernel = DenoiseKernel::nlm(10.0, 3, 10);
        let expected = DenoiseAlgorithm::NonLocalMeans {
            h: 10.0,
            patch_radius: 3,
            search_radius: 10,
        };
        assert_eq!(kernel.algorithm(), expected);
    }

    #[test]
    fn test_estimate_gflops_not_zero() {
        for kernel in [DenoiseKernel::gaussian(1.5), DenoiseKernel::nlm(10.0, 3, 10)] {
            assert!(kernel.estimate_gflops(1920, 1080) > 0.0);
        }
    }

    // ---- GPU bilateral path (uses fallback device in test environments) ----

    #[test]
    fn test_bilateral_denoise_via_denoise_fn_constant() {
        // No wgpu adapter at all (e.g. headless CI) means there is nothing to test.
        let Ok(device) = GpuDevice::new_fallback() else {
            return;
        };

        let (w, h) = (8u32, 8u32);
        let input = gray_image(w, h, 128);
        let mut output = zeroed_rgba(w, h);
        let algorithm = DenoiseAlgorithm::BilateralFilter {
            sigma_spatial: 1.5,
            sigma_range: 30.0,
        };

        let result = DenoiseOperation::denoise(&device, &input, &mut output, w, h, algorithm);

        // A constant image must stay constant whichever path (GPU or CPU) ran.
        assert!(result.is_ok(), "denoise returned error: {:?}", result.err());
        for &v in &output {
            assert_eq!(v, 128, "constant image must be preserved");
        }
    }

    #[test]
    fn test_bilateral_denoise_params_struct_size() {
        // The CPU-side params must be exactly 32 bytes to match the WGSL struct.
        let size = std::mem::size_of::<BilateralParams>();
        assert_eq!(size, 32);
    }
}