Skip to main content

oximedia_gpu/ops/
denoise.rs

1//! GPU-accelerated video denoising operations.
2//!
3//! This module provides compute-shader-based denoise filters for video frames.
4//! Three algorithms are available:
5//!
6//! - **Gaussian** – simple spatial blur (low latency, lower quality)
7//! - **`NonLocalMeans`** – patch-based NLM denoising (higher quality, more compute)
8//! - **`BilateralFilter`** – edge-preserving spatial filter (good quality/speed trade-off)
9//!
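//! The Gaussian and NLM filters always run on the CPU. The bilateral filter runs
//! on the GPU when a real adapter is available and falls back to the CPU
//! implementation if the GPU path reports an error. The CPU implementations are
//! parallelised across pixels with rayon (plain scalar code, not hand-written
//! SIMD). Failing to create a device at all (the `GpuDevice::new` error path) is
//! left to the caller, who may switch to a software path.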
13//!
14//! ## GPU path (bilateral filter)
15//!
16//! When `device.is_fallback == false` the bilateral filter is executed on the GPU
//! via a WGSL compute shader (`shaders/bilateral.wgsl`). The NLM algorithm is
//! too expensive for a single-dispatch compute shader: each output pixel costs
//! O((2·search_radius + 1)² × (2·patch_radius + 1)²) work. It therefore runs
//! only on the CPU for now; GPU NLM is planned for a future release using
//! multi-pass tiled reduction.
21
22use crate::{
23    shader::{BindGroupLayoutBuilder, ShaderCompiler, ShaderSource},
24    GpuDevice, GpuError, Result,
25};
26use bytemuck::{Pod, Zeroable};
27use once_cell::sync::OnceCell;
28use rayon::prelude::*;
29use wgpu::{BindGroupLayout, ComputePipeline};
30
31// ============================================================================
32// Denoise algorithm selector
33// ============================================================================
34
/// Which denoising filter to apply, together with its tuning parameters.
///
/// All variants operate on packed 8-bit RGBA frames.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum DenoiseAlgorithm {
    /// Plain spatial Gaussian blur.
    ///
    /// Cheapest option; it smooths edges just as much as noise.
    Gaussian {
        /// Kernel standard deviation in pixels (for example `1.5`).
        /// The kernel reaches `ceil(3 * sigma)` pixels on each side.
        sigma: f32,
    },

    /// Non-Local Means: every pixel becomes a weighted average of pixels in a
    /// search window, weighted by how similar their surrounding patches are.
    NonLocalMeans {
        /// Filter strength *h*; larger values smooth more (for example `10.0`).
        h: f32,
        /// Half-size of the square patch that is compared (for example `3`).
        patch_radius: u32,
        /// Half-size of the square search window (for example `10`).
        search_radius: u32,
    },

    /// Edge-preserving bilateral filter: neighbours are weighted both by
    /// spatial distance and by colour difference.
    BilateralFilter {
        /// Standard deviation of the spatial (distance) Gaussian.
        sigma_spatial: f32,
        /// Standard deviation of the range (pixel-value) Gaussian.
        sigma_range: f32,
    },
}
74
75// ============================================================================
76// Parameter structs (GPU-uploadable)
77// ============================================================================
78
79#[repr(C)]
80#[derive(Clone, Copy, Pod, Zeroable)]
81struct GaussianDenoiseParams {
82    width: u32,
83    height: u32,
84    kernel_radius: u32,
85    _pad: u32,
86    sigma: f32,
87    inv_two_sigma_sq: f32,
88    _pad2: [f32; 2],
89}
90
91/// GPU-side uniform layout (must match `BilateralParams` struct in bilateral.wgsl exactly).
92#[repr(C)]
93#[derive(Clone, Copy, Pod, Zeroable)]
94struct BilateralParams {
95    width: u32,
96    height: u32,
97    kernel_radius: u32,
98    _pad: u32,
99    sigma_spatial: f32,
100    sigma_range: f32,
101    inv_two_sigma_s_sq: f32,
102    inv_two_sigma_r_sq: f32,
103}
104
105#[repr(C)]
106#[derive(Clone, Copy, Pod, Zeroable)]
107struct NlmParams {
108    width: u32,
109    height: u32,
110    patch_radius: u32,
111    search_radius: u32,
112    h_sq: f32,
113    inv_patch_area: f32,
114    _pad: [f32; 2],
115}
116
117// ============================================================================
118// Public DenoiseOperation API
119// ============================================================================
120
/// Entry point for the denoise filters in this module.
///
/// # Execution paths
///
/// * **Bilateral**: when a real GPU adapter is available (`!device.is_fallback`),
///   the filter runs as a wgpu compute shader (entry point
///   `bilateral_filter_main` in `shaders/bilateral.wgsl`). If that path returns
///   an error, a warning is logged and the CPU implementation is used instead.
/// * **Gaussian** and **NLM** always run on the CPU. A GPU NLM path is planned
///   for a future release.
///
/// The CPU implementations are parallelised across pixels with rayon.
#[derive(Debug, Clone, Copy, Default)]
pub struct DenoiseOperation;
133
134impl DenoiseOperation {
135    /// Denoise an RGBA image using the selected algorithm.
136    ///
137    /// Both `input` and `output` must be `width * height * 4` bytes
138    /// (packed RGBA, one byte per channel).
139    ///
140    /// # Errors
141    ///
142    /// Returns an error if buffer sizes are invalid.
143    pub fn denoise(
144        device: &GpuDevice,
145        input: &[u8],
146        output: &mut [u8],
147        width: u32,
148        height: u32,
149        algorithm: DenoiseAlgorithm,
150    ) -> Result<()> {
151        super::utils::validate_dimensions(width, height)?;
152        super::utils::validate_buffer_size(input, width, height, 4)?;
153        super::utils::validate_buffer_size(output, width, height, 4)?;
154
155        match algorithm {
156            DenoiseAlgorithm::Gaussian { sigma } => {
157                Self::denoise_gaussian_cpu(input, output, width, height, sigma)
158            }
159            DenoiseAlgorithm::BilateralFilter {
160                sigma_spatial,
161                sigma_range,
162            } => {
163                // Prefer GPU path when the device is a real (non-fallback) adapter.
164                if !device.is_fallback {
165                    match Self::denoise_bilateral_gpu(
166                        device,
167                        input,
168                        output,
169                        width,
170                        height,
171                        sigma_spatial,
172                        sigma_range,
173                    ) {
174                        Ok(()) => return Ok(()),
175                        Err(e) => {
176                            tracing::warn!(
177                                "GPU bilateral filter failed ({e}), falling back to CPU"
178                            );
179                        }
180                    }
181                }
182                // CPU fallback (also used for software adapters).
183                Self::denoise_bilateral_cpu(
184                    input,
185                    output,
186                    width,
187                    height,
188                    sigma_spatial,
189                    sigma_range,
190                )
191            }
192            // NLM GPU path is future work — CPU fallback only.
193            DenoiseAlgorithm::NonLocalMeans {
194                h,
195                patch_radius,
196                search_radius,
197            } => {
198                Self::denoise_nlm_cpu(input, output, width, height, h, patch_radius, search_radius)
199            }
200        }
201    }
202
203    // -----------------------------------------------------------------------
204    // GPU path — bilateral filter
205    // -----------------------------------------------------------------------
206
207    /// Run the bilateral filter on the GPU via wgpu compute shader.
208    #[allow(clippy::cast_possible_truncation)]
209    fn denoise_bilateral_gpu(
210        device: &GpuDevice,
211        input: &[u8],
212        output: &mut [u8],
213        width: u32,
214        height: u32,
215        sigma_spatial: f32,
216        sigma_range: f32,
217    ) -> Result<()> {
218        use super::utils::{
219            calculate_dispatch_size, create_readback_buffer, create_storage_buffer,
220            create_uniform_buffer,
221        };
222
223        let pipeline = Self::get_bilateral_pipeline(device)?;
224        let layout = Self::get_bilateral_bind_group_layout(device)?;
225
226        // Build uniform params buffer.
227        let kernel_radius = (3.0 * sigma_spatial).ceil() as u32;
228        let inv_two_ss_sq = 1.0 / (2.0 * sigma_spatial * sigma_spatial);
229        let inv_two_sr_sq = 1.0 / (2.0 * sigma_range * sigma_range);
230
231        let params = BilateralParams {
232            width,
233            height,
234            kernel_radius,
235            _pad: 0,
236            sigma_spatial,
237            sigma_range,
238            inv_two_sigma_s_sq: inv_two_ss_sq,
239            inv_two_sigma_r_sq: inv_two_sr_sq,
240        };
241
242        // Pack the u8 RGBA input as u32 words (shader reads array<u32>).
243        // Each u32 packs one RGBA pixel: R<<24 | G<<16 | B<<8 | A.
244        let num_pixels = (width * height) as usize;
245        let mut input_u32: Vec<u32> = Vec::with_capacity(num_pixels);
246        for chunk in input.chunks_exact(4) {
247            let packed = ((chunk[0] as u32) << 24)
248                | ((chunk[1] as u32) << 16)
249                | ((chunk[2] as u32) << 8)
250                | (chunk[3] as u32);
251            input_u32.push(packed);
252        }
253
254        let input_bytes = bytemuck::cast_slice(&input_u32);
255        let output_len = num_pixels * 4; // u32 per pixel, 4 bytes each
256
257        let input_buffer = create_storage_buffer(device, input_bytes.len() as u64)?;
258        let output_buffer = create_storage_buffer(device, output_len as u64)?;
259        let params_buffer = create_uniform_buffer(device, bytemuck::bytes_of(&params))?;
260
261        device
262            .queue()
263            .write_buffer(input_buffer.buffer(), 0, input_bytes);
264
265        // Build bind group.
266        let compiler = ShaderCompiler::new(device);
267        let bind_group = compiler.create_bind_group(
268            "Bilateral Bind Group",
269            layout,
270            &[
271                wgpu::BindGroupEntry {
272                    binding: 0,
273                    resource: input_buffer.buffer().as_entire_binding(),
274                },
275                wgpu::BindGroupEntry {
276                    binding: 1,
277                    resource: output_buffer.buffer().as_entire_binding(),
278                },
279                wgpu::BindGroupEntry {
280                    binding: 2,
281                    resource: params_buffer.buffer().as_entire_binding(),
282                },
283            ],
284        );
285
286        // Dispatch compute.
287        {
288            let mut encoder =
289                device
290                    .device()
291                    .create_command_encoder(&wgpu::CommandEncoderDescriptor {
292                        label: Some("Bilateral Compute Encoder"),
293                    });
294            {
295                let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
296                    label: Some("Bilateral Compute Pass"),
297                    timestamp_writes: None,
298                });
299                pass.set_pipeline(pipeline);
300                pass.set_bind_group(0, &bind_group, &[]);
301                let (dx, dy) = calculate_dispatch_size(width, height, (16, 16));
302                pass.dispatch_workgroups(dx, dy, 1);
303            }
304            device.queue().submit(Some(encoder.finish()));
305        }
306
307        // Copy output_buffer → readback buffer.
308        let readback = create_readback_buffer(device, output_len as u64)?;
309        {
310            let mut encoder =
311                device
312                    .device()
313                    .create_command_encoder(&wgpu::CommandEncoderDescriptor {
314                        label: Some("Bilateral Readback Encoder"),
315                    });
316            output_buffer.copy_to(&mut encoder, &readback, 0, 0, output_len as u64)?;
317            device.queue().submit(Some(encoder.finish()));
318        }
319
320        device.wait();
321
322        // Read back and unpack u32 → RGBA bytes.
323        let raw = readback.read(device, 0, output_len as u64)?;
324        let result_u32: &[u32] = bytemuck::cast_slice(&raw);
325        for (i, &packed) in result_u32.iter().enumerate() {
326            output[i * 4] = ((packed >> 24) & 0xFF) as u8;
327            output[i * 4 + 1] = ((packed >> 16) & 0xFF) as u8;
328            output[i * 4 + 2] = ((packed >> 8) & 0xFF) as u8;
329            output[i * 4 + 3] = (packed & 0xFF) as u8;
330        }
331
332        Ok(())
333    }
334
335    // -----------------------------------------------------------------------
336    // Pipeline management (cached per process lifetime via OnceCell)
337    // -----------------------------------------------------------------------
338
339    fn get_bilateral_bind_group_layout(device: &GpuDevice) -> Result<&'static BindGroupLayout> {
340        static LAYOUT: OnceCell<BindGroupLayout> = OnceCell::new();
341        Ok(LAYOUT.get_or_init(|| {
342            let compiler = ShaderCompiler::new(device);
343            let entries = BindGroupLayoutBuilder::new()
344                .add_storage_buffer_read_only(0) // input
345                .add_storage_buffer(1) // output
346                .add_uniform_buffer(2) // params
347                .build();
348            compiler.create_bind_group_layout("Bilateral Bind Group Layout", &entries)
349        }))
350    }
351
352    fn get_bilateral_pipeline(device: &GpuDevice) -> Result<&'static ComputePipeline> {
353        static PIPELINE: OnceCell<std::result::Result<ComputePipeline, String>> = OnceCell::new();
354        PIPELINE
355            .get_or_init(|| Self::init_bilateral_pipeline(device))
356            .as_ref()
357            .map_err(|e| GpuError::PipelineCreation(e.clone()))
358    }
359
360    fn init_bilateral_pipeline(device: &GpuDevice) -> std::result::Result<ComputePipeline, String> {
361        let compiler = ShaderCompiler::new(device);
362        let shader = compiler
363            .compile(
364                "Bilateral Shader",
365                ShaderSource::Embedded(crate::shader::embedded::BILATERAL_SHADER),
366            )
367            .map_err(|e| format!("Failed to compile bilateral shader: {e}"))?;
368
369        let layout = Self::get_bilateral_bind_group_layout(device)
370            .map_err(|e| format!("Failed to create bilateral bind group layout: {e}"))?;
371
372        compiler
373            .create_pipeline(
374                "Bilateral Pipeline",
375                &shader,
376                "bilateral_filter_main",
377                layout,
378            )
379            .map_err(|e| format!("Failed to create bilateral pipeline: {e}"))
380    }
381
382    // -----------------------------------------------------------------------
383    // CPU SIMD fallback implementations
384    // -----------------------------------------------------------------------
385
386    /// Gaussian denoise via separable 1D convolution.
387    #[allow(clippy::cast_possible_truncation)]
388    fn denoise_gaussian_cpu(
389        input: &[u8],
390        output: &mut [u8],
391        width: u32,
392        height: u32,
393        sigma: f32,
394    ) -> Result<()> {
395        let w = width as usize;
396        let h = height as usize;
397        let radius = (3.0 * sigma).ceil() as usize;
398
399        // Build separable 1D Gaussian kernel.
400        let kernel_size = 2 * radius + 1;
401        let mut kernel = vec![0.0f32; kernel_size];
402        let two_sigma_sq = 2.0 * sigma * sigma;
403        let mut sum = 0.0f32;
404        for (i, k) in kernel.iter_mut().enumerate() {
405            let x = i as f32 - radius as f32;
406            *k = (-(x * x) / two_sigma_sq).exp();
407            sum += *k;
408        }
409        for k in &mut kernel {
410            *k /= sum;
411        }
412
413        // Horizontal pass into temp buffer.
414        let mut temp = vec![0u8; input.len()];
415        temp.par_chunks_exact_mut(w * 4)
416            .enumerate()
417            .for_each(|(y, row)| {
418                if y >= h {
419                    return;
420                }
421                for x in 0..w {
422                    for c in 0..4usize {
423                        let mut acc = 0.0f32;
424                        for (ki, &kv) in kernel.iter().enumerate() {
425                            let sx = (x as i64 + ki as i64 - radius as i64).clamp(0, w as i64 - 1)
426                                as usize;
427                            acc += kv * f32::from(input[(y * w + sx) * 4 + c]);
428                        }
429                        row[x * 4 + c] = acc.round().clamp(0.0, 255.0) as u8;
430                    }
431                }
432            });
433
434        // Vertical pass from temp to output.
435        output
436            .par_chunks_exact_mut(4)
437            .enumerate()
438            .for_each(|(i, pixel)| {
439                let x = i % w;
440                let y = i / w;
441                if y >= h {
442                    return;
443                }
444                for c in 0..4usize {
445                    let mut acc = 0.0f32;
446                    for (ki, &kv) in kernel.iter().enumerate() {
447                        let sy =
448                            (y as i64 + ki as i64 - radius as i64).clamp(0, h as i64 - 1) as usize;
449                        acc += kv * f32::from(temp[(sy * w + x) * 4 + c]);
450                    }
451                    pixel[c] = acc.round().clamp(0.0, 255.0) as u8;
452                }
453            });
454
455        Ok(())
456    }
457
458    /// Bilateral filter — CPU fallback (edge-preserving denoising).
459    ///
460    /// Exposed as `pub(crate)` so that `FilterOperation` can delegate to this
461    /// implementation without duplicating the algorithm.
462    #[allow(clippy::cast_possible_truncation)]
463    pub(crate) fn denoise_bilateral_cpu(
464        input: &[u8],
465        output: &mut [u8],
466        width: u32,
467        height: u32,
468        sigma_spatial: f32,
469        sigma_range: f32,
470    ) -> Result<()> {
471        let w = width as usize;
472        let h = height as usize;
473        let radius = (3.0 * sigma_spatial).ceil() as usize;
474        let inv_two_ss_sq = 1.0 / (2.0 * sigma_spatial * sigma_spatial);
475        let inv_two_sr_sq = 1.0 / (2.0 * sigma_range * sigma_range);
476
477        output
478            .par_chunks_exact_mut(4)
479            .enumerate()
480            .for_each(|(i, pixel)| {
481                let x = i % w;
482                let y = i / w;
483                if y >= h {
484                    return;
485                }
486
487                let center = [
488                    f32::from(input[(y * w + x) * 4]),
489                    f32::from(input[(y * w + x) * 4 + 1]),
490                    f32::from(input[(y * w + x) * 4 + 2]),
491                    f32::from(input[(y * w + x) * 4 + 3]),
492                ];
493
494                let mut acc = [0.0f32; 4];
495                let mut weight_sum = 0.0f32;
496
497                for dy in -(radius as i64)..=(radius as i64) {
498                    for dx in -(radius as i64)..=(radius as i64) {
499                        let sx = (x as i64 + dx).clamp(0, w as i64 - 1) as usize;
500                        let sy = (y as i64 + dy).clamp(0, h as i64 - 1) as usize;
501
502                        let spatial_dist_sq = (dx * dx + dy * dy) as f32;
503                        let w_spatial = (-spatial_dist_sq * inv_two_ss_sq).exp();
504
505                        let neighbor = [
506                            f32::from(input[(sy * w + sx) * 4]),
507                            f32::from(input[(sy * w + sx) * 4 + 1]),
508                            f32::from(input[(sy * w + sx) * 4 + 2]),
509                            f32::from(input[(sy * w + sx) * 4 + 3]),
510                        ];
511
512                        let range_dist_sq = (0..3)
513                            .map(|c| (center[c] - neighbor[c]).powi(2))
514                            .sum::<f32>();
515                        let w_range = (-range_dist_sq * inv_two_sr_sq).exp();
516
517                        let w_total = w_spatial * w_range;
518                        weight_sum += w_total;
519
520                        for c in 0..4 {
521                            acc[c] += w_total * neighbor[c];
522                        }
523                    }
524                }
525
526                if weight_sum > 0.0 {
527                    for c in 0..4 {
528                        pixel[c] = (acc[c] / weight_sum).round().clamp(0.0, 255.0) as u8;
529                    }
530                } else {
531                    pixel.copy_from_slice(&input[i * 4..i * 4 + 4]);
532                }
533            });
534
535        Ok(())
536    }
537
538    /// Non-Local Means denoising (CPU path).
539    ///
540    /// GPU NLM is future work — the O(search_area² × patch_area) per-pixel cost
541    /// requires a multi-pass tiled reduction that does not map naturally to a
542    /// single compute dispatch.
543    #[allow(clippy::cast_possible_truncation)]
544    fn denoise_nlm_cpu(
545        input: &[u8],
546        output: &mut [u8],
547        width: u32,
548        height: u32,
549        h: f32,
550        patch_radius: u32,
551        search_radius: u32,
552    ) -> Result<()> {
553        if h <= 0.0 {
554            return Err(GpuError::Internal(
555                "NLM filter strength h must be positive".to_string(),
556            ));
557        }
558
559        let w = width as usize;
560        let ht = height as usize;
561        let pr = patch_radius as usize;
562        let sr = search_radius as usize;
563        let h_sq = h * h;
564        let patch_area = ((2 * pr + 1) * (2 * pr + 1)) as f32;
565        let inv_h_sq_patch = 1.0 / (h_sq * patch_area);
566
567        output
568            .par_chunks_exact_mut(4)
569            .enumerate()
570            .for_each(|(i, pixel)| {
571                let px = i % w;
572                let py = i / w;
573                if py >= ht {
574                    return;
575                }
576
577                let mut acc = [0.0f32; 4];
578                let mut weight_sum = 0.0f32;
579
580                // Iterate over search window.
581                for qy in
582                    (py as i64 - sr as i64).max(0)..=(py as i64 + sr as i64).min(ht as i64 - 1)
583                {
584                    for qx in
585                        (px as i64 - sr as i64).max(0)..=(px as i64 + sr as i64).min(w as i64 - 1)
586                    {
587                        // Compute patch distance between (px,py) and (qx,qy).
588                        let mut patch_dist_sq = 0.0f32;
589                        for ky in -(pr as i64)..=(pr as i64) {
590                            for kx in -(pr as i64)..=(pr as i64) {
591                                let p_x = (px as i64 + kx).clamp(0, w as i64 - 1) as usize;
592                                let p_y = (py as i64 + ky).clamp(0, ht as i64 - 1) as usize;
593                                let q_x = (qx + kx).clamp(0, w as i64 - 1) as usize;
594                                let q_y = (qy + ky).clamp(0, ht as i64 - 1) as usize;
595
596                                // Use luma (channel 0) for patch comparison.
597                                let diff = f32::from(input[(p_y * w + p_x) * 4])
598                                    - f32::from(input[(q_y * w + q_x) * 4]);
599                                patch_dist_sq += diff * diff;
600                            }
601                        }
602
603                        let w_nlm = (-patch_dist_sq * inv_h_sq_patch).exp();
604                        weight_sum += w_nlm;
605
606                        for c in 0..4 {
607                            acc[c] +=
608                                w_nlm * f32::from(input[(qy as usize * w + qx as usize) * 4 + c]);
609                        }
610                    }
611                }
612
613                if weight_sum > 0.0 {
614                    for c in 0..4 {
615                        pixel[c] = (acc[c] / weight_sum).round().clamp(0.0, 255.0) as u8;
616                    }
617                } else {
618                    pixel.copy_from_slice(&input[i * 4..i * 4 + 4]);
619                }
620            });
621
622        Ok(())
623    }
624
625    /// Validate that `sigma > 0.0` and return a descriptive error if not.
626    #[allow(dead_code)]
627    fn check_sigma(sigma: f32, name: &str) -> Result<()> {
628        if sigma <= 0.0 {
629            Err(GpuError::Internal(format!(
630                "{name} must be positive, got {sigma}"
631            )))
632        } else {
633            Ok(())
634        }
635    }
636
637    /// Convenience: Gaussian denoise with automatic sigma selection.
638    ///
639    /// `noise_level` is in the range \[0.0, 1.0\] where 0.0 = no noise
640    /// and 1.0 = heavy noise.
641    ///
642    /// # Errors
643    ///
644    /// Returns an error if buffer sizes are invalid.
645    pub fn auto_denoise(
646        device: &GpuDevice,
647        input: &[u8],
648        output: &mut [u8],
649        width: u32,
650        height: u32,
651        noise_level: f32,
652    ) -> Result<()> {
653        let sigma = noise_level.clamp(0.0, 1.0) * 3.0 + 0.5;
654        Self::denoise(
655            device,
656            input,
657            output,
658            width,
659            height,
660            DenoiseAlgorithm::Gaussian { sigma },
661        )
662    }
663}
664
665// ============================================================================
666// Denoise kernel wrappers (kernel module integration)
667// ============================================================================
668
669/// Denoise kernel configuration for use with the `kernels` module.
670#[derive(Debug, Clone)]
671pub struct DenoiseKernel {
672    algorithm: DenoiseAlgorithm,
673}
674
675impl DenoiseKernel {
676    /// Create a new denoise kernel with the given algorithm.
677    #[must_use]
678    pub fn new(algorithm: DenoiseAlgorithm) -> Self {
679        Self { algorithm }
680    }
681
682    /// Create a Gaussian denoise kernel.
683    #[must_use]
684    pub fn gaussian(sigma: f32) -> Self {
685        Self::new(DenoiseAlgorithm::Gaussian { sigma })
686    }
687
688    /// Create a bilateral filter denoise kernel.
689    #[must_use]
690    pub fn bilateral(sigma_spatial: f32, sigma_range: f32) -> Self {
691        Self::new(DenoiseAlgorithm::BilateralFilter {
692            sigma_spatial,
693            sigma_range,
694        })
695    }
696
697    /// Create an NLM denoise kernel.
698    #[must_use]
699    pub fn nlm(h: f32, patch_radius: u32, search_radius: u32) -> Self {
700        Self::new(DenoiseAlgorithm::NonLocalMeans {
701            h,
702            patch_radius,
703            search_radius,
704        })
705    }
706
707    /// Apply this kernel to an RGBA image.
708    ///
709    /// # Errors
710    ///
711    /// Returns an error if buffer sizes are invalid or the operation fails.
712    pub fn apply(
713        &self,
714        device: &GpuDevice,
715        input: &[u8],
716        output: &mut [u8],
717        width: u32,
718        height: u32,
719    ) -> Result<()> {
720        DenoiseOperation::denoise(device, input, output, width, height, self.algorithm)
721    }
722
723    /// Get the algorithm used by this kernel.
724    #[must_use]
725    pub fn algorithm(&self) -> DenoiseAlgorithm {
726        self.algorithm
727    }
728
729    /// Estimate GFLOP for `width × height` frame at this algorithm.
730    #[must_use]
731    pub fn estimate_gflops(&self, width: u32, height: u32) -> f64 {
732        let pixels = u64::from(width) * u64::from(height);
733        let ops: u64 = match self.algorithm {
734            DenoiseAlgorithm::Gaussian { sigma } => {
735                let r = (3.0 * sigma).ceil() as u64;
736                let k = 2 * r + 1;
737                pixels * k * 4 * 4 // 2 passes, ~4 ops/tap, 4 channels
738            }
739            DenoiseAlgorithm::BilateralFilter { sigma_spatial, .. } => {
740                let r = (3.0 * sigma_spatial).ceil() as u64;
741                let k = (2 * r + 1).pow(2);
742                pixels * k * 12 * 4 // exp + mul + add, 4 channels
743            }
744            DenoiseAlgorithm::NonLocalMeans {
745                patch_radius,
746                search_radius,
747                ..
748            } => {
749                let pr = u64::from(2 * patch_radius + 1).pow(2);
750                let sr = u64::from(2 * search_radius + 1).pow(2);
751                pixels * sr * pr * 5 // patch distance + weighting
752            }
753        };
754        ops as f64 / 1e9
755    }
756}
757
758// ============================================================================
759// Tests
760// ============================================================================
761
#[cfg(test)]
mod tests {
    use super::*;

    /// Uniform RGBA image with every channel set to `value`.
    fn gray_image(w: u32, h: u32, value: u8) -> Vec<u8> {
        vec![value; (w * h * 4) as usize]
    }

    /// Deterministic high-frequency "noise" pattern.
    fn noisy_image(w: u32, h: u32) -> Vec<u8> {
        (0..(w * h * 4))
            .map(|i| (i as u8).wrapping_mul(37))
            .collect()
    }

    /// Sum of absolute differences between horizontally adjacent pixels
    /// (same channel) over the whole image. This is a simple roughness measure.
    fn horizontal_total_variation(img: &[u8], w: u32) -> u64 {
        img.chunks_exact(w as usize * 4)
            .flat_map(|row| row.iter().zip(&row[4..]))
            .map(|(&a, &b)| u64::from(a.abs_diff(b)))
            .sum()
    }

    // ---- CPU algorithms -----------------------------------------------------

    #[test]
    fn test_gaussian_denoise_cpu_constant_image() {
        let w = 16u32;
        let h = 16u32;
        let input = gray_image(w, h, 200);
        let mut output = vec![0u8; (w * h * 4) as usize];
        let result = DenoiseOperation::denoise_gaussian_cpu(&input, &mut output, w, h, 1.5);
        assert!(result.is_ok());
        // A constant image must pass through unchanged.
        for &v in &output {
            assert_eq!(v, 200);
        }
    }

    #[test]
    fn test_gaussian_denoise_cpu_noisy() {
        let w = 32u32;
        let h = 32u32;
        let input = noisy_image(w, h);
        let mut output = vec![0u8; (w * h * 4) as usize];
        let result = DenoiseOperation::denoise_gaussian_cpu(&input, &mut output, w, h, 2.0);
        assert!(result.is_ok());
        assert!(output.iter().any(|&v| v > 0));
        // Blurring must reduce pixel-to-pixel variation.
        assert!(
            horizontal_total_variation(&output, w) < horizontal_total_variation(&input, w),
            "Gaussian blur should smooth the noisy pattern"
        );
    }

    #[test]
    fn test_bilateral_denoise_cpu_constant() {
        let w = 8u32;
        let h = 8u32;
        let input = gray_image(w, h, 100);
        let mut output = vec![0u8; (w * h * 4) as usize];
        let result = DenoiseOperation::denoise_bilateral_cpu(&input, &mut output, w, h, 1.5, 30.0);
        assert!(result.is_ok());
        for &v in &output {
            assert_eq!(v, 100);
        }
    }

    #[test]
    fn test_bilateral_denoise_cpu_preserves_hard_edge() {
        // Left half black, right half white. With a small range sigma the
        // cross-edge weights underflow to zero, so the edge must survive intact.
        let w = 8u32;
        let h = 8u32;
        let input: Vec<u8> = (0..w * h)
            .flat_map(|i| {
                let v = if i % w < w / 2 { 0 } else { 255 };
                [v; 4]
            })
            .collect();
        let mut output = vec![0u8; input.len()];
        let result = DenoiseOperation::denoise_bilateral_cpu(&input, &mut output, w, h, 1.5, 10.0);
        assert!(result.is_ok());
        assert_eq!(output, input);
    }

    #[test]
    fn test_nlm_denoise_cpu_constant() {
        let w = 8u32;
        let h = 8u32;
        let input = gray_image(w, h, 150);
        let mut output = vec![0u8; (w * h * 4) as usize];
        let result = DenoiseOperation::denoise_nlm_cpu(&input, &mut output, w, h, 10.0, 2, 5);
        assert!(result.is_ok());
        for &v in &output {
            assert_eq!(v, 150);
        }
    }

    #[test]
    fn test_nlm_denoise_invalid_h() {
        let w = 4u32;
        let h = 4u32;
        let input = gray_image(w, h, 0);
        let mut output = vec![0u8; (w * h * 4) as usize];
        let result = DenoiseOperation::denoise_nlm_cpu(&input, &mut output, w, h, 0.0, 1, 3);
        assert!(result.is_err());
    }

    // ---- DenoiseKernel -----------------------------------------------------

    #[test]
    fn test_denoise_kernel_gaussian() {
        let k = DenoiseKernel::gaussian(1.0);
        assert_eq!(k.algorithm(), DenoiseAlgorithm::Gaussian { sigma: 1.0 });
    }

    #[test]
    fn test_denoise_kernel_bilateral() {
        let k = DenoiseKernel::bilateral(2.0, 25.0);
        assert_eq!(
            k.algorithm(),
            DenoiseAlgorithm::BilateralFilter {
                sigma_spatial: 2.0,
                sigma_range: 25.0,
            }
        );
    }

    #[test]
    fn test_denoise_kernel_nlm() {
        let k = DenoiseKernel::nlm(10.0, 3, 10);
        assert_eq!(
            k.algorithm(),
            DenoiseAlgorithm::NonLocalMeans {
                h: 10.0,
                patch_radius: 3,
                search_radius: 10,
            }
        );
    }

    #[test]
    fn test_estimate_gflops_not_zero() {
        let k = DenoiseKernel::gaussian(1.5);
        assert!(k.estimate_gflops(1920, 1080) > 0.0);
        // sigma 1.5 → radius 5 → 11 taps; 1920·1080·11·16 ops.
        assert!((k.estimate_gflops(1920, 1080) - 0.364_953_6).abs() < 1e-12);
        assert!(k.estimate_gflops(3840, 2160) > k.estimate_gflops(1920, 1080));

        let k2 = DenoiseKernel::nlm(10.0, 3, 10);
        assert!(k2.estimate_gflops(1920, 1080) > 0.0);
    }

    // ---- Dispatch through `denoise` ------------------------------------------

    #[test]
    fn test_bilateral_denoise_via_denoise_fn_constant() {
        // Request a fallback (software) device; skip if none can be created.
        let device = match GpuDevice::new_fallback() {
            Ok(d) => d,
            Err(_) => return, // headless CI with no wgpu adapter — skip
        };

        let w = 8u32;
        let h = 8u32;
        let input = gray_image(w, h, 128);
        let mut output = vec![0u8; (w * h * 4) as usize];

        let result = DenoiseOperation::denoise(
            &device,
            &input,
            &mut output,
            w,
            h,
            DenoiseAlgorithm::BilateralFilter {
                sigma_spatial: 1.5,
                sigma_range: 30.0,
            },
        );

        // Constant image must stay constant regardless of GPU/CPU path.
        assert!(result.is_ok(), "denoise returned error: {:?}", result.err());
        for &v in &output {
            assert_eq!(v, 128, "constant image must be preserved");
        }
    }

    // ---- GPU parameter layouts ---------------------------------------------

    #[test]
    fn test_bilateral_denoise_params_struct_size() {
        // The CPU-side params layout must be exactly 32 bytes to match the WGSL struct.
        assert_eq!(std::mem::size_of::<BilateralParams>(), 32);
    }

    #[test]
    fn test_gaussian_and_nlm_params_struct_size() {
        assert_eq!(std::mem::size_of::<GaussianDenoiseParams>(), 32);
        assert_eq!(std::mem::size_of::<NlmParams>(), 32);
    }
}