// oximedia_gpu/ops/filter.rs

//! Convolution filter operations (blur, sharpen, edge detection) on the GPU, plus CPU box, median, bilateral and separable Gaussian filters.
2
3use crate::{
4    shader::{BindGroupLayoutBuilder, ShaderCompiler, ShaderSource},
5    GpuDevice, GpuError, Result,
6};
7use bytemuck::{Pod, Zeroable};
8use once_cell::sync::OnceCell;
9use wgpu::{BindGroup, BindGroupLayout, ComputePipeline};
10
11use super::utils;
12
13#[repr(C)]
14#[derive(Copy, Clone, Pod, Zeroable)]
15struct FilterParams {
16    width: u32,
17    height: u32,
18    stride: u32,
19    kernel_size: u32,
20    normalize: u32,
21    filter_type: u32,
22    padding: u32,
23    sigma: f32,
24}
25
26/// Convolution filter operations
27pub struct FilterOperation;
28
29impl FilterOperation {
30    /// Apply Gaussian blur
31    ///
32    /// # Arguments
33    ///
34    /// * `device` - GPU device
35    /// * `input` - Input image buffer (packed RGBA format)
36    /// * `output` - Output image buffer (packed RGBA format)
37    /// * `width` - Image width
38    /// * `height` - Image height
39    /// * `sigma` - Blur radius (standard deviation)
40    ///
41    /// # Errors
42    ///
43    /// Returns an error if buffer sizes are invalid or if the GPU operation fails.
44    #[allow(clippy::too_many_arguments)]
45    pub fn gaussian_blur(
46        device: &GpuDevice,
47        input: &[u8],
48        output: &mut [u8],
49        width: u32,
50        height: u32,
51        sigma: f32,
52    ) -> Result<()> {
53        utils::validate_dimensions(width, height)?;
54        utils::validate_buffer_size(input, width, height, 4)?;
55        utils::validate_buffer_size(output, width, height, 4)?;
56
57        let kernel_size = Self::calculate_kernel_size(sigma);
58        let pipeline = Self::get_gaussian_pipeline(device)?;
59        let layout = Self::get_bind_group_layout(device)?;
60
61        Self::execute_filter(
62            device,
63            pipeline,
64            layout,
65            input,
66            output,
67            width,
68            height,
69            kernel_size,
70            1, // Gaussian filter type
71            sigma,
72        )
73    }
74
75    /// Apply sharpening filter (unsharp mask)
76    ///
77    /// # Arguments
78    ///
79    /// * `device` - GPU device
80    /// * `input` - Input image buffer (packed RGBA format)
81    /// * `output` - Output image buffer (packed RGBA format)
82    /// * `width` - Image width
83    /// * `height` - Image height
84    /// * `amount` - Sharpening strength
85    ///
86    /// # Errors
87    ///
88    /// Returns an error if buffer sizes are invalid or if the GPU operation fails.
89    #[allow(clippy::too_many_arguments)]
90    pub fn sharpen(
91        device: &GpuDevice,
92        input: &[u8],
93        output: &mut [u8],
94        width: u32,
95        height: u32,
96        amount: f32,
97    ) -> Result<()> {
98        utils::validate_dimensions(width, height)?;
99        utils::validate_buffer_size(input, width, height, 4)?;
100        utils::validate_buffer_size(output, width, height, 4)?;
101
102        let pipeline = Self::get_sharpen_pipeline(device)?;
103        let layout = Self::get_bind_group_layout(device)?;
104
105        Self::execute_filter(
106            device, pipeline, layout, input, output, width, height,
107            5, // Kernel size for sharpening
108            2, // Sharpen filter type
109            amount,
110        )
111    }
112
113    /// Detect edges using Sobel operator
114    ///
115    /// # Arguments
116    ///
117    /// * `device` - GPU device
118    /// * `input` - Input image buffer (packed RGBA format)
119    /// * `output` - Output image buffer (packed RGBA format)
120    /// * `width` - Image width
121    /// * `height` - Image height
122    ///
123    /// # Errors
124    ///
125    /// Returns an error if buffer sizes are invalid or if the GPU operation fails.
126    pub fn edge_detect(
127        device: &GpuDevice,
128        input: &[u8],
129        output: &mut [u8],
130        width: u32,
131        height: u32,
132    ) -> Result<()> {
133        utils::validate_dimensions(width, height)?;
134        utils::validate_buffer_size(input, width, height, 4)?;
135        utils::validate_buffer_size(output, width, height, 4)?;
136
137        let pipeline = Self::get_edge_detect_pipeline(device)?;
138        let layout = Self::get_bind_group_layout(device)?;
139
140        Self::execute_filter(
141            device, pipeline, layout, input, output, width, height, 3, // 3x3 Sobel kernel
142            3, // Edge detect filter type
143            0.0,
144        )
145    }
146
147    /// Apply custom convolution kernel
148    ///
149    /// # Arguments
150    ///
151    /// * `device` - GPU device
152    /// * `input` - Input image buffer (packed RGBA format)
153    /// * `output` - Output image buffer (packed RGBA format)
154    /// * `width` - Image width
155    /// * `height` - Image height
156    /// * `kernel` - Convolution kernel (must be square and odd-sized)
157    /// * `normalize` - Whether to normalize the kernel
158    ///
159    /// # Errors
160    ///
161    /// Returns an error if buffer sizes are invalid or if the GPU operation fails.
162    #[allow(clippy::too_many_arguments)]
163    pub fn convolve(
164        device: &GpuDevice,
165        input: &[u8],
166        output: &mut [u8],
167        width: u32,
168        height: u32,
169        kernel: &[f32],
170        normalize: bool,
171    ) -> Result<()> {
172        utils::validate_dimensions(width, height)?;
173        utils::validate_buffer_size(input, width, height, 4)?;
174        utils::validate_buffer_size(output, width, height, 4)?;
175
176        let kernel_size = (kernel.len() as f32).sqrt() as u32;
177        if kernel_size * kernel_size != kernel.len() as u32 {
178            return Err(GpuError::Internal("Kernel must be square".to_string()));
179        }
180        if kernel_size % 2 == 0 {
181            return Err(GpuError::Internal("Kernel size must be odd".to_string()));
182        }
183
184        let pipeline = Self::get_convolve_pipeline(device)?;
185        let layout = Self::get_bind_group_layout_with_kernel(device)?;
186
187        Self::execute_convolve(
188            device,
189            pipeline,
190            layout,
191            input,
192            output,
193            width,
194            height,
195            kernel,
196            kernel_size,
197            normalize,
198        )
199    }
200
201    #[allow(clippy::too_many_arguments)]
202    fn execute_filter(
203        device: &GpuDevice,
204        pipeline: &ComputePipeline,
205        layout: &BindGroupLayout,
206        input: &[u8],
207        output: &mut [u8],
208        width: u32,
209        height: u32,
210        kernel_size: u32,
211        filter_type: u32,
212        sigma: f32,
213    ) -> Result<()> {
214        // Create buffers
215        let input_buffer = utils::create_storage_buffer(device, input.len() as u64)?;
216        let output_buffer = utils::create_storage_buffer(device, output.len() as u64)?;
217
218        // Upload input data
219        device.queue().write_buffer(input_buffer.buffer(), 0, input);
220
221        // Create uniform buffer for parameters
222        let params = FilterParams {
223            width,
224            height,
225            stride: width,
226            kernel_size,
227            normalize: 1,
228            filter_type,
229            padding: 0,
230            sigma,
231        };
232        let params_bytes = bytemuck::bytes_of(&params);
233        let params_buffer = utils::create_uniform_buffer(device, params_bytes)?;
234
235        // Create bind group
236        let compiler = ShaderCompiler::new(device);
237        let bind_group = compiler.create_bind_group(
238            "Filter Bind Group",
239            layout,
240            &[
241                wgpu::BindGroupEntry {
242                    binding: 0,
243                    resource: input_buffer.buffer().as_entire_binding(),
244                },
245                wgpu::BindGroupEntry {
246                    binding: 1,
247                    resource: output_buffer.buffer().as_entire_binding(),
248                },
249                wgpu::BindGroupEntry {
250                    binding: 2,
251                    resource: params_buffer.buffer().as_entire_binding(),
252                },
253            ],
254        );
255
256        // Execute compute pass
257        Self::dispatch_compute(device, pipeline, &bind_group, width, height)?;
258
259        // Read back results
260        let readback_buffer = utils::create_readback_buffer(device, output.len() as u64)?;
261        let mut encoder = device
262            .device()
263            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
264                label: Some("Filter Copy Encoder"),
265            });
266
267        output_buffer.copy_to(&mut encoder, &readback_buffer, 0, 0, output.len() as u64)?;
268
269        device.queue().submit(Some(encoder.finish()));
270        device.wait();
271
272        let result = readback_buffer.read(device, 0, output.len() as u64)?;
273        output.copy_from_slice(&result);
274
275        Ok(())
276    }
277
278    #[allow(clippy::too_many_arguments)]
279    fn execute_convolve(
280        device: &GpuDevice,
281        pipeline: &ComputePipeline,
282        layout: &BindGroupLayout,
283        input: &[u8],
284        output: &mut [u8],
285        width: u32,
286        height: u32,
287        kernel: &[f32],
288        kernel_size: u32,
289        normalize: bool,
290    ) -> Result<()> {
291        // Create buffers
292        let input_buffer = utils::create_storage_buffer(device, input.len() as u64)?;
293        let output_buffer = utils::create_storage_buffer(device, output.len() as u64)?;
294
295        // Upload input data
296        device.queue().write_buffer(input_buffer.buffer(), 0, input);
297
298        // Create kernel buffer
299        let kernel_bytes = bytemuck::cast_slice(kernel);
300        let kernel_buffer = utils::create_storage_buffer(device, kernel_bytes.len() as u64)?;
301        device
302            .queue()
303            .write_buffer(kernel_buffer.buffer(), 0, kernel_bytes);
304
305        // Create uniform buffer for parameters
306        let params = FilterParams {
307            width,
308            height,
309            stride: width,
310            kernel_size,
311            normalize: u32::from(normalize),
312            filter_type: 0, // Custom kernel
313            padding: 0,
314            sigma: 0.0,
315        };
316        let params_bytes = bytemuck::bytes_of(&params);
317        let params_buffer = utils::create_uniform_buffer(device, params_bytes)?;
318
319        // Create bind group
320        let compiler = ShaderCompiler::new(device);
321        let bind_group = compiler.create_bind_group(
322            "Filter Bind Group",
323            layout,
324            &[
325                wgpu::BindGroupEntry {
326                    binding: 0,
327                    resource: input_buffer.buffer().as_entire_binding(),
328                },
329                wgpu::BindGroupEntry {
330                    binding: 1,
331                    resource: output_buffer.buffer().as_entire_binding(),
332                },
333                wgpu::BindGroupEntry {
334                    binding: 2,
335                    resource: params_buffer.buffer().as_entire_binding(),
336                },
337                wgpu::BindGroupEntry {
338                    binding: 3,
339                    resource: kernel_buffer.buffer().as_entire_binding(),
340                },
341            ],
342        );
343
344        // Execute compute pass
345        Self::dispatch_compute(device, pipeline, &bind_group, width, height)?;
346
347        // Read back results
348        let readback_buffer = utils::create_readback_buffer(device, output.len() as u64)?;
349        let mut encoder = device
350            .device()
351            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
352                label: Some("Filter Copy Encoder"),
353            });
354
355        output_buffer.copy_to(&mut encoder, &readback_buffer, 0, 0, output.len() as u64)?;
356
357        device.queue().submit(Some(encoder.finish()));
358        device.wait();
359
360        let result = readback_buffer.read(device, 0, output.len() as u64)?;
361        output.copy_from_slice(&result);
362
363        Ok(())
364    }
365
366    fn dispatch_compute(
367        device: &GpuDevice,
368        pipeline: &ComputePipeline,
369        bind_group: &BindGroup,
370        width: u32,
371        height: u32,
372    ) -> Result<()> {
373        let mut encoder = device
374            .device()
375            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
376                label: Some("Filter Compute Encoder"),
377            });
378
379        {
380            let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
381                label: Some("Filter Compute Pass"),
382                timestamp_writes: None,
383            });
384
385            compute_pass.set_pipeline(pipeline);
386            compute_pass.set_bind_group(0, bind_group, &[]);
387
388            let (dispatch_x, dispatch_y) = utils::calculate_dispatch_size(width, height, (16, 16));
389            compute_pass.dispatch_workgroups(dispatch_x, dispatch_y, 1);
390        }
391
392        device.queue().submit(Some(encoder.finish()));
393        Ok(())
394    }
395
396    fn calculate_kernel_size(sigma: f32) -> u32 {
397        // Use 3-sigma rule: kernel size = 2 * ceil(3 * sigma) + 1
398        let radius = (3.0 * sigma).ceil() as u32;
399        2 * radius + 1
400    }
401
402    fn get_bind_group_layout(device: &GpuDevice) -> Result<&'static BindGroupLayout> {
403        static LAYOUT: OnceCell<BindGroupLayout> = OnceCell::new();
404
405        Ok(LAYOUT.get_or_init(|| {
406            let compiler = ShaderCompiler::new(device);
407            let entries = BindGroupLayoutBuilder::new()
408                .add_storage_buffer_read_only(0) // input
409                .add_storage_buffer(1) // output
410                .add_uniform_buffer(2) // params
411                .build();
412
413            compiler.create_bind_group_layout("Filter Bind Group Layout", &entries)
414        }))
415    }
416
417    fn get_bind_group_layout_with_kernel(device: &GpuDevice) -> Result<&'static BindGroupLayout> {
418        static LAYOUT: OnceCell<BindGroupLayout> = OnceCell::new();
419
420        Ok(LAYOUT.get_or_init(|| {
421            let compiler = ShaderCompiler::new(device);
422            let entries = BindGroupLayoutBuilder::new()
423                .add_storage_buffer_read_only(0) // input
424                .add_storage_buffer(1) // output
425                .add_uniform_buffer(2) // params
426                .add_storage_buffer_read_only(3) // kernel
427                .build();
428
429            compiler.create_bind_group_layout("Filter Bind Group Layout (with kernel)", &entries)
430        }))
431    }
432
433    fn init_pipeline(
434        device: &GpuDevice,
435        name: &str,
436        entry_point: &str,
437        layout_fn: fn(&GpuDevice) -> Result<&'static BindGroupLayout>,
438    ) -> std::result::Result<ComputePipeline, String> {
439        let compiler = ShaderCompiler::new(device);
440        let shader = compiler
441            .compile(
442                "Filter Shader",
443                ShaderSource::Embedded(crate::shader::embedded::FILTER_SHADER),
444            )
445            .map_err(|e| format!("Failed to compile filter shader: {e}"))?;
446
447        let layout =
448            layout_fn(device).map_err(|e| format!("Failed to create bind group layout: {e}"))?;
449
450        compiler
451            .create_pipeline(name, &shader, entry_point, layout)
452            .map_err(|e| format!("Failed to create pipeline: {e}"))
453    }
454
455    fn get_gaussian_pipeline(device: &GpuDevice) -> Result<&'static ComputePipeline> {
456        static PIPELINE: OnceCell<std::result::Result<ComputePipeline, String>> = OnceCell::new();
457
458        PIPELINE
459            .get_or_init(|| {
460                FilterOperation::init_pipeline(
461                    device,
462                    "Gaussian Blur Pipeline",
463                    "convolve_main",
464                    Self::get_bind_group_layout,
465                )
466            })
467            .as_ref()
468            .map_err(|e| crate::GpuError::PipelineCreation(e.clone()))
469    }
470
471    fn get_sharpen_pipeline(device: &GpuDevice) -> Result<&'static ComputePipeline> {
472        static PIPELINE: OnceCell<std::result::Result<ComputePipeline, String>> = OnceCell::new();
473
474        PIPELINE
475            .get_or_init(|| {
476                FilterOperation::init_pipeline(
477                    device,
478                    "Sharpen Pipeline",
479                    "unsharp_mask",
480                    Self::get_bind_group_layout,
481                )
482            })
483            .as_ref()
484            .map_err(|e| crate::GpuError::PipelineCreation(e.clone()))
485    }
486
487    fn get_edge_detect_pipeline(device: &GpuDevice) -> Result<&'static ComputePipeline> {
488        static PIPELINE: OnceCell<std::result::Result<ComputePipeline, String>> = OnceCell::new();
489
490        PIPELINE
491            .get_or_init(|| {
492                FilterOperation::init_pipeline(
493                    device,
494                    "Edge Detect Pipeline",
495                    "edge_detect",
496                    Self::get_bind_group_layout,
497                )
498            })
499            .as_ref()
500            .map_err(|e| crate::GpuError::PipelineCreation(e.clone()))
501    }
502
503    fn get_convolve_pipeline(device: &GpuDevice) -> Result<&'static ComputePipeline> {
504        static PIPELINE: OnceCell<std::result::Result<ComputePipeline, String>> = OnceCell::new();
505
506        PIPELINE
507            .get_or_init(|| {
508                FilterOperation::init_pipeline(
509                    device,
510                    "Convolve Pipeline",
511                    "convolve_main",
512                    Self::get_bind_group_layout_with_kernel,
513                )
514            })
515            .as_ref()
516            .map_err(|e| crate::GpuError::PipelineCreation(e.clone()))
517    }
518}
519
520// ---------------------------------------------------------------------------
521// Separable CPU Gaussian blur (Task 9)
522// ---------------------------------------------------------------------------
523
/// Build a normalised 1-D Gaussian kernel of radius `ceil(3σ)`.
///
/// The returned `Vec<f32>` has odd length `2*radius+1`, is symmetric about its
/// centre and sums to 1.0 (up to `f32` rounding).
///
/// If `sigma` is ≤ 0, NaN or infinite, the single-element identity kernel
/// `[1.0]` is returned.
///
/// The kernel length grows linearly with `sigma` (about `6σ + 1` taps), so
/// callers should keep `sigma` within a range that makes sense for their
/// images.
#[must_use]
pub fn gaussian_kernel_1d(sigma: f32) -> Vec<f32> {
    // Written positively so NaN (for which every comparison is false) lands in
    // the identity branch. Infinity would otherwise overflow `2 * radius + 1`.
    if !(sigma.is_finite() && sigma > 0.0) {
        return vec![1.0_f32];
    }
    let radius = (3.0 * sigma).ceil() as usize;
    let two_sigma_sq = 2.0 * sigma * sigma;
    let mut kernel: Vec<f32> = (0..=2 * radius)
        .map(|i| {
            let x = i as f32 - radius as f32;
            (-x * x / two_sigma_sq).exp()
        })
        .collect();
    let sum: f32 = kernel.iter().sum();
    for k in &mut kernel {
        *k /= sum;
    }
    kernel
}
549
550/// CPU-side separable Gaussian blur (two-pass: horizontal then vertical).
551///
552/// # Arguments
553///
554/// * `input`  — source RGBA bytes (`width × height × 4`)
555/// * `output` — destination RGBA bytes (same size as input)
556/// * `width`, `height` — image dimensions
557/// * `sigma`  — Gaussian standard deviation in pixels (> 0)
558///
559/// # Errors
560///
561/// Returns an error if buffer sizes do not match `width × height × 4`.
562pub fn gaussian_blur_separable(
563    input: &[u8],
564    output: &mut [u8],
565    width: u32,
566    height: u32,
567    sigma: f32,
568) -> crate::Result<()> {
569    utils::validate_dimensions(width, height)?;
570    utils::validate_buffer_size(input, width, height, 4)?;
571    utils::validate_buffer_size(output, width, height, 4)?;
572
573    let w = width as usize;
574    let h = height as usize;
575    let kernel = gaussian_kernel_1d(sigma);
576    let radius = kernel.len() / 2;
577
578    // Horizontal pass — accumulate into f32 buffer to avoid clamping artefacts
579    let mut h_pass = vec![0.0_f32; w * h * 4];
580    for row in 0..h {
581        for col in 0..w {
582            let mut acc = [0.0_f32; 4];
583            let mut wsum = 0.0_f32;
584            for (ki, &kw) in kernel.iter().enumerate() {
585                let sc = col as isize + ki as isize - radius as isize;
586                if sc < 0 || sc >= w as isize {
587                    continue;
588                }
589                let src = (row * w + sc as usize) * 4;
590                for c in 0..4 {
591                    acc[c] += kw * input[src + c] as f32;
592                }
593                wsum += kw;
594            }
595            let dst = (row * w + col) * 4;
596            let inv = if wsum > 0.0 { 1.0 / wsum } else { 1.0 };
597            for c in 0..4 {
598                h_pass[dst + c] = acc[c] * inv;
599            }
600        }
601    }
602
603    // Vertical pass — read from h_pass, write to output
604    for row in 0..h {
605        for col in 0..w {
606            let mut acc = [0.0_f32; 4];
607            let mut wsum = 0.0_f32;
608            for (ki, &kw) in kernel.iter().enumerate() {
609                let sr = row as isize + ki as isize - radius as isize;
610                if sr < 0 || sr >= h as isize {
611                    continue;
612                }
613                let src = (sr as usize * w + col) * 4;
614                for c in 0..4 {
615                    acc[c] += kw * h_pass[src + c];
616                }
617                wsum += kw;
618            }
619            let dst = (row * w + col) * 4;
620            let inv = if wsum > 0.0 { 1.0 / wsum } else { 1.0 };
621            for c in 0..4 {
622                output[dst + c] = (acc[c] * inv).round().clamp(0.0, 255.0) as u8;
623            }
624        }
625    }
626
627    Ok(())
628}
629
630use rayon::prelude::*;
631
632/// CPU-side separable Gaussian blur with Rayon parallel row/column processing.
633///
634/// This is an optimised variant of [`gaussian_blur_separable`] that
635/// parallelises both the horizontal and vertical passes using Rayon.
636///
637/// # Arguments
638///
639/// * `input`  — source RGBA bytes (`width × height × 4`)
640/// * `output` — destination RGBA bytes (same size as input)
641/// * `width`, `height` — image dimensions
642/// * `sigma`  — Gaussian standard deviation in pixels (> 0)
643///
644/// # Errors
645///
646/// Returns an error if buffer sizes do not match `width × height × 4`.
647pub fn gaussian_blur_separable_parallel(
648    input: &[u8],
649    output: &mut [u8],
650    width: u32,
651    height: u32,
652    sigma: f32,
653) -> crate::Result<()> {
654    utils::validate_dimensions(width, height)?;
655    utils::validate_buffer_size(input, width, height, 4)?;
656    utils::validate_buffer_size(output, width, height, 4)?;
657
658    let w = width as usize;
659    let h = height as usize;
660    let kernel = gaussian_kernel_1d(sigma);
661    let radius = kernel.len() / 2;
662
663    // Horizontal pass (parallel over rows)
664    let mut h_pass = vec![0.0_f32; w * h * 4];
665    h_pass
666        .par_chunks_exact_mut(w * 4)
667        .enumerate()
668        .for_each(|(row, row_out)| {
669            for col in 0..w {
670                let mut acc = [0.0_f32; 4];
671                let mut wsum = 0.0_f32;
672                for (ki, &kw) in kernel.iter().enumerate() {
673                    let sc = col as isize + ki as isize - radius as isize;
674                    if sc < 0 || sc >= w as isize {
675                        continue;
676                    }
677                    let src = (row * w + sc as usize) * 4;
678                    for c in 0..4 {
679                        acc[c] += kw * input[src + c] as f32;
680                    }
681                    wsum += kw;
682                }
683                let inv = if wsum > 0.0 { 1.0 / wsum } else { 1.0 };
684                let dst = col * 4;
685                for c in 0..4 {
686                    row_out[dst + c] = acc[c] * inv;
687                }
688            }
689        });
690
691    // Vertical pass (parallel over columns)
692    output
693        .par_chunks_exact_mut(4)
694        .enumerate()
695        .for_each(|(px_idx, px_out)| {
696            let row = px_idx / w;
697            let col = px_idx % w;
698            let mut acc = [0.0_f32; 4];
699            let mut wsum = 0.0_f32;
700            for (ki, &kw) in kernel.iter().enumerate() {
701                let sr = row as isize + ki as isize - radius as isize;
702                if sr < 0 || sr >= h as isize {
703                    continue;
704                }
705                let src = (sr as usize * w + col) * 4;
706                for c in 0..4 {
707                    acc[c] += kw * h_pass[src + c];
708                }
709                wsum += kw;
710            }
711            let inv = if wsum > 0.0 { 1.0 / wsum } else { 1.0 };
712            for c in 0..4 {
713                px_out[c] = (acc[c] * inv).round().clamp(0.0, 255.0) as u8;
714            }
715        });
716
717    Ok(())
718}
719
720// ---------------------------------------------------------------------------
721// CPU-side box blur (separable sliding-sum, O(w*h) per channel)
722// ---------------------------------------------------------------------------
723
724/// CPU box blur: separable two-pass sliding-sum, O(w×h) per channel.
725///
726/// Border pixels are handled via clamped (replicate-border) indexing.
727/// The horizontal pass builds an intermediate `u32` buffer; the vertical pass
728/// writes into the final `Vec<u8>`.
729///
730/// # Errors
731///
732/// Returns an error if the buffer length does not match `width × height × channels`.
733pub fn box_blur(
734    data: &[u8],
735    width: u32,
736    height: u32,
737    channels: u32,
738    radius: u32,
739) -> crate::Result<Vec<u8>> {
740    let w = width as usize;
741    let h = height as usize;
742    let ch = channels as usize;
743    let expected = w * h * ch;
744    if data.len() != expected {
745        return Err(crate::GpuError::InvalidBufferSize {
746            expected,
747            actual: data.len(),
748        });
749    }
750    if w == 0 || h == 0 {
751        return Ok(data.to_vec());
752    }
753
754    let r = radius as isize;
755
756    // --- Horizontal pass ---
757    // For each (row, col, channel): average over columns [col-r .. col+r] (clamped).
758    // Use a sliding-sum that tracks left/right clamped edges.
759    let mut h_pass = vec![0u32; w * h * ch];
760    for row in 0..h {
761        for c in 0..ch {
762            // Build initial window sum for col = 0.
763            let right0 = r.min(w as isize - 1) as usize;
764            let mut window_sum: u32 = 0;
765            for kc in 0..=right0 {
766                window_sum += u32::from(data[(row * w + kc) * ch + c]);
767            }
768
769            for col in 0..w {
770                // Compute the actual left/right clamped boundaries for this col.
771                let left = (col as isize - r).max(0) as usize;
772                let right = (col as isize + r).min(w as isize - 1) as usize;
773
774                if col > 0 {
775                    // Previous column's boundaries.
776                    let prev_left = ((col as isize - 1) - r).max(0) as usize;
777                    let prev_right = ((col as isize - 1) + r).min(w as isize - 1) as usize;
778                    // Remove pixel that dropped off the left.
779                    if left > prev_left {
780                        window_sum -= u32::from(data[(row * w + prev_left) * ch + c]);
781                    }
782                    // Add pixel that entered on the right.
783                    if right > prev_right {
784                        window_sum += u32::from(data[(row * w + right) * ch + c]);
785                    }
786                }
787
788                let window_len = (right - left + 1) as u32;
789                // Round-to-nearest division.
790                h_pass[(row * w + col) * ch + c] = (window_sum + window_len / 2) / window_len;
791            }
792        }
793    }
794
795    // --- Vertical pass ---
796    // For each (row, col, channel): average over rows [row-r .. row+r] (clamped).
797    let mut output = vec![0u8; expected];
798    for col in 0..w {
799        for c in 0..ch {
800            // Build initial window sum for row = 0.
801            let bot0 = r.min(h as isize - 1) as usize;
802            let mut window_sum: u32 = 0;
803            for kr in 0..=bot0 {
804                window_sum += h_pass[(kr * w + col) * ch + c];
805            }
806
807            for row in 0..h {
808                let top = (row as isize - r).max(0) as usize;
809                let bot = (row as isize + r).min(h as isize - 1) as usize;
810
811                if row > 0 {
812                    let prev_top = ((row as isize - 1) - r).max(0) as usize;
813                    let prev_bot = ((row as isize - 1) + r).min(h as isize - 1) as usize;
814                    if top > prev_top {
815                        window_sum -= h_pass[(prev_top * w + col) * ch + c];
816                    }
817                    if bot > prev_bot {
818                        window_sum += h_pass[(bot * w + col) * ch + c];
819                    }
820                }
821
822                let window_len = (bot - top + 1) as u32;
823                let avg = (window_sum + window_len / 2) / window_len;
824                output[(row * w + col) * ch + c] = avg.clamp(0, 255) as u8;
825            }
826        }
827    }
828
829    Ok(output)
830}
831
832// ---------------------------------------------------------------------------
833// CPU-side median filter
834// ---------------------------------------------------------------------------
835
836/// CPU median filter: sorts a `(2r+1)×(2r+1)` neighbourhood per pixel/channel.
837///
838/// Border pixels use clamped neighbour coordinates (replicate border).
839///
840/// # Errors
841///
842/// Returns an error if the buffer length does not match `width × height × channels`.
843pub fn median_filter(
844    data: &[u8],
845    width: u32,
846    height: u32,
847    channels: u32,
848    radius: u32,
849) -> crate::Result<Vec<u8>> {
850    let w = width as usize;
851    let h = height as usize;
852    let ch = channels as usize;
853    let expected = w * h * ch;
854    if data.len() != expected {
855        return Err(crate::GpuError::InvalidBufferSize {
856            expected,
857            actual: data.len(),
858        });
859    }
860    if w == 0 || h == 0 {
861        return Ok(data.to_vec());
862    }
863
864    let r = radius as isize;
865    let window_len = ((2 * r + 1) * (2 * r + 1)) as usize;
866    let mut output = vec![0u8; expected];
867
868    for row in 0..h {
869        for col in 0..w {
870            for c in 0..ch {
871                let mut window: Vec<u8> = Vec::with_capacity(window_len);
872                for dy in -r..=r {
873                    for dx in -r..=r {
874                        let sr = (row as isize + dy).clamp(0, h as isize - 1) as usize;
875                        let sc = (col as isize + dx).clamp(0, w as isize - 1) as usize;
876                        window.push(data[(sr * w + sc) * ch + c]);
877                    }
878                }
879                window.sort_unstable();
880                output[(row * w + col) * ch + c] = window[window.len() / 2];
881            }
882        }
883    }
884
885    Ok(output)
886}
887
888// ---------------------------------------------------------------------------
889// CPU-side bilateral filter wrapper (delegates to DenoiseOperation)
890// ---------------------------------------------------------------------------
891
892/// CPU bilateral filter: edge-preserving spatial filter.
893///
894/// Delegates to `denoise_bilateral_cpu` so
895/// there is a single canonical implementation.
896///
897/// # Errors
898///
899/// Returns an error if the buffer length does not match `width × height × channels`,
900/// or if `channels != 4` (the bilateral implementation is RGBA-only).
901pub fn bilateral_filter(
902    data: &[u8],
903    width: u32,
904    height: u32,
905    channels: u32,
906    sigma_spatial: f32,
907    sigma_range: f32,
908) -> crate::Result<Vec<u8>> {
909    if channels != 4 {
910        return Err(crate::GpuError::NotSupported(format!(
911            "bilateral_filter requires channels == 4, got {channels}"
912        )));
913    }
914    utils::validate_buffer_size(data, width, height, 4)?;
915    let mut output = vec![0u8; data.len()];
916    super::DenoiseOperation::denoise_bilateral_cpu(
917        data,
918        &mut output,
919        width,
920        height,
921        sigma_spatial,
922        sigma_range,
923    )?;
924    Ok(output)
925}
926
/// Returns the largest absolute per-byte difference between two pixel buffers.
///
/// Handy for checking that the serial and parallel separable blurs agree exactly
/// (result 0) or nearly so. Only the common prefix is compared: if the slices have
/// different lengths, the extra trailing bytes of the longer one are ignored.
/// Two empty buffers yield 0.
#[must_use]
pub fn max_channel_diff(a: &[u8], b: &[u8]) -> u32 {
    a.iter()
        .zip(b)
        .fold(0u32, |worst, (&lhs, &rhs)| worst.max(u32::from(lhs.abs_diff(rhs))))
}
939
#[cfg(test)]
mod tests {
    use super::*;

    // ── Shared fixtures ───────────────────────────────────────────────────────

    /// Builds a `w × h` opaque RGBA checkerboard in row-major order:
    /// RGB = 255 where `row + col` is even, 0 otherwise; alpha is always 255.
    fn rgba_checkerboard(w: u32, h: u32) -> Vec<u8> {
        let (w, h) = (w as usize, h as usize);
        let mut buf = Vec::with_capacity(w * h * 4);
        for row in 0..h {
            for col in 0..w {
                let v = if (row + col) % 2 == 0 { 255u8 } else { 0u8 };
                buf.extend_from_slice(&[v, v, v, 255]);
            }
        }
        buf
    }

    /// Largest R/G/B byte over all RGBA pixels (alpha ignored); 0 for an empty buffer.
    fn max_rgb_channel(rgba: &[u8]) -> u8 {
        rgba.chunks(4)
            .flat_map(|px| &px[..3])
            .copied()
            .max()
            .unwrap_or(0)
    }

    /// Spread (`max - min`) of the red channel over all RGBA pixels; 0 for an empty buffer.
    fn red_channel_range(rgba: &[u8]) -> u32 {
        let reds = || rgba.chunks(4).map(|px| u32::from(px[0]));
        reds().max().unwrap_or(0) - reds().min().unwrap_or(0)
    }

    // ── Gaussian kernel tests ─────────────────────────────────────────────────

    #[test]
    fn test_kernel_sums_to_one() {
        let k = gaussian_kernel_1d(1.0);
        let sum: f32 = k.iter().sum();
        assert!((sum - 1.0).abs() < 1e-5, "kernel sum = {sum}");
    }

    #[test]
    fn test_kernel_is_symmetric() {
        let k = gaussian_kernel_1d(2.0);
        let n = k.len();
        for i in 0..n / 2 {
            assert!(
                (k[i] - k[n - 1 - i]).abs() < 1e-6,
                "asymmetric at index {i}: {} vs {}",
                k[i],
                k[n - 1 - i]
            );
        }
    }

    #[test]
    fn test_kernel_center_is_largest() {
        let k = gaussian_kernel_1d(1.5);
        let center = k[k.len() / 2];
        for &v in &k {
            assert!(center >= v, "center {center} not >= {v}");
        }
    }

    #[test]
    fn test_kernel_zero_sigma_returns_identity() {
        let k = gaussian_kernel_1d(0.0);
        assert_eq!(k.len(), 1);
        assert!((k[0] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_kernel_negative_sigma_returns_identity() {
        let k = gaussian_kernel_1d(-1.0);
        assert_eq!(k.len(), 1);
        assert!((k[0] - 1.0).abs() < 1e-6);
    }

    // ── Serial blur tests ─────────────────────────────────────────────────────

    #[test]
    fn test_blur_uniform_image_unchanged() {
        let w = 8u32;
        let h = 8u32;
        let input: Vec<u8> = (0..(w * h * 4) as usize)
            .map(|i| if i % 4 == 3 { 255 } else { 128 })
            .collect();
        let mut output = vec![0u8; (w * h * 4) as usize];
        gaussian_blur_separable(&input, &mut output, w, h, 1.5).expect("blur should succeed");
        for (i, (&inp, &out)) in input.iter().zip(output.iter()).enumerate() {
            assert!(
                (inp as i32 - out as i32).unsigned_abs() <= 1,
                "pixel {i}: input={inp} output={out}"
            );
        }
    }

    #[test]
    fn test_blur_reduces_contrast() {
        let w = 4u32;
        let h = 4u32;
        let input = rgba_checkerboard(w, h);
        let mut output = vec![0u8; (w * h * 4) as usize];
        gaussian_blur_separable(&input, &mut output, w, h, 1.0).expect("blur should succeed");
        let max_rgb = max_rgb_channel(&output);
        assert!(
            max_rgb < 255,
            "max_rgb after blur = {max_rgb}; expected < 255"
        );
    }

    #[test]
    fn test_blur_size_mismatch_returns_error() {
        let w = 4u32;
        let h = 4u32;
        let input = vec![0u8; (w * h * 4) as usize];
        let mut output = vec![0u8; 10];
        let result = gaussian_blur_separable(&input, &mut output, w, h, 1.0);
        assert!(result.is_err());
    }

    #[test]
    fn test_blur_single_pixel_passthrough() {
        let input = vec![100u8, 150u8, 200u8, 255u8];
        let mut output = vec![0u8; 4];
        gaussian_blur_separable(&input, &mut output, 1, 1, 1.0).expect("blur should succeed");
        assert_eq!(output[0], 100);
        assert_eq!(output[1], 150);
        assert_eq!(output[2], 200);
        assert_eq!(output[3], 255);
    }

    // ── Parallel blur tests ───────────────────────────────────────────────────

    #[test]
    fn test_parallel_blur_matches_serial_uniform_image() {
        let w = 16u32;
        let h = 16u32;
        let input: Vec<u8> = vec![128u8; (w * h * 4) as usize];
        let mut serial = vec![0u8; (w * h * 4) as usize];
        let mut parallel = vec![0u8; (w * h * 4) as usize];
        gaussian_blur_separable(&input, &mut serial, w, h, 1.5).expect("serial blur");
        gaussian_blur_separable_parallel(&input, &mut parallel, w, h, 1.5).expect("parallel blur");
        assert_eq!(
            max_channel_diff(&serial, &parallel),
            0,
            "serial and parallel must agree on uniform image"
        );
    }

    #[test]
    fn test_parallel_blur_matches_serial_random_image() {
        let w = 8u32;
        let h = 8u32;
        let input: Vec<u8> = (0..(w * h * 4) as usize)
            .map(|i| ((i * 37 + 13) % 256) as u8)
            .collect();
        let mut serial = vec![0u8; (w * h * 4) as usize];
        let mut parallel = vec![0u8; (w * h * 4) as usize];
        gaussian_blur_separable(&input, &mut serial, w, h, 1.0).expect("serial blur");
        gaussian_blur_separable_parallel(&input, &mut parallel, w, h, 1.0).expect("parallel blur");
        let max_diff = max_channel_diff(&serial, &parallel);
        assert_eq!(max_diff, 0, "serial and parallel outputs must be identical");
    }

    #[test]
    fn test_parallel_blur_single_pixel_passthrough() {
        let input = vec![77u8, 88, 99, 255];
        let mut output = vec![0u8; 4];
        gaussian_blur_separable_parallel(&input, &mut output, 1, 1, 2.0)
            .expect("single pixel parallel blur");
        assert_eq!(output[0], 77);
        assert_eq!(output[1], 88);
        assert_eq!(output[2], 99);
        assert_eq!(output[3], 255);
    }

    #[test]
    fn test_parallel_blur_size_mismatch_returns_error() {
        let input = vec![0u8; 4 * 4 * 4];
        let mut output = vec![0u8; 5]; // wrong
        let res = gaussian_blur_separable_parallel(&input, &mut output, 4, 4, 1.0);
        assert!(res.is_err());
    }

    #[test]
    fn test_parallel_blur_reduces_contrast() {
        let w = 8u32;
        let h = 8u32;
        let input = rgba_checkerboard(w, h);
        let mut output = vec![0u8; (w * h * 4) as usize];
        gaussian_blur_separable_parallel(&input, &mut output, w, h, 1.5)
            .expect("parallel contrast blur");
        let max_rgb = max_rgb_channel(&output);
        assert!(
            max_rgb < 255,
            "parallel blur should reduce max brightness; got {max_rgb}"
        );
    }

    #[test]
    fn test_parallel_blur_large_sigma_heavy_smoothing() {
        let w = 16u32;
        let h = 16u32;
        // Alternating black/white checkerboard.
        let input = rgba_checkerboard(w, h);
        let mut out_small = vec![0u8; (w * h * 4) as usize];
        let mut out_large = vec![0u8; (w * h * 4) as usize];
        gaussian_blur_separable_parallel(&input, &mut out_small, w, h, 0.5).expect("small sigma");
        gaussian_blur_separable_parallel(&input, &mut out_large, w, h, 3.0).expect("large sigma");

        let range_small = red_channel_range(&out_small);
        let range_large = red_channel_range(&out_large);
        assert!(
            range_large <= range_small,
            "larger sigma should produce smaller contrast range; small={range_small}, large={range_large}"
        );
    }

    #[test]
    fn test_parallel_blur_wide_image() {
        let w = 32u32;
        let h = 4u32;
        let input: Vec<u8> = (0..(w * h * 4) as usize).map(|i| (i % 256) as u8).collect();
        let mut output = vec![0u8; (w * h * 4) as usize];
        gaussian_blur_separable_parallel(&input, &mut output, w, h, 1.0)
            .expect("wide image parallel blur");
        assert_eq!(output.len(), (w * h * 4) as usize);
    }

    #[test]
    fn test_parallel_blur_tall_image() {
        let w = 4u32;
        let h = 32u32;
        let input: Vec<u8> = (0..(w * h * 4) as usize).map(|i| (i % 256) as u8).collect();
        let mut output = vec![0u8; (w * h * 4) as usize];
        gaussian_blur_separable_parallel(&input, &mut output, w, h, 1.0)
            .expect("tall image parallel blur");
        assert_eq!(output.len(), (w * h * 4) as usize);
    }

    // ── max_channel_diff tests ────────────────────────────────────────────────

    #[test]
    fn test_max_channel_diff_identical() {
        let a = vec![128u8; 16];
        let diff = max_channel_diff(&a, &a);
        assert_eq!(diff, 0);
    }

    #[test]
    fn test_max_channel_diff_known_values() {
        let a = vec![100u8, 200, 50, 255];
        let b = vec![90u8, 210, 50, 255];
        let diff = max_channel_diff(&a, &b);
        assert_eq!(diff, 10);
    }

    // ── box_blur tests ────────────────────────────────────────────────────────

    #[test]
    fn test_box_blur_uniform() {
        // A 4×4 image filled with value 128 (3 RGB channels) should pass through
        // unchanged — the box average of identical values is the same value.
        let w = 4u32;
        let h = 4u32;
        let ch = 3u32;
        let value: u8 = 128;
        let input = vec![value; (w * h * ch) as usize];
        let output = box_blur(&input, w, h, ch, 2).expect("box_blur should succeed");
        for (i, &v) in output.iter().enumerate() {
            assert!(
                (v as i32 - value as i32).abs() <= 1,
                "pixel byte {i}: expected {value}, got {v}"
            );
        }
    }

    #[test]
    fn test_box_blur_spike() {
        // 7×7 dark image (value 0) with a single bright pixel at the centre.
        // After box blur with radius=1 the centre value should decrease and its
        // 8-neighbours should become > 0.
        let w = 7u32;
        let h = 7u32;
        let ch = 1u32;
        let mut input = vec![0u8; (w * h * ch) as usize];
        let cx = 3usize;
        let cy = 3usize;
        input[cy * w as usize + cx] = 255;

        let output = box_blur(&input, w, h, ch, 1).expect("box_blur spike should succeed");

        // Centre must be reduced.
        let centre = output[cy * w as usize + cx];
        assert!(
            centre < 255,
            "centre pixel should be reduced after box blur, got {centre}"
        );

        // At least one immediate neighbour must be > 0.
        let right = output[cy * w as usize + cx + 1];
        let below = output[(cy + 1) * w as usize + cx];
        assert!(
            right > 0 || below > 0,
            "neighbours should receive energy; right={right}, below={below}"
        );
    }

    #[test]
    fn test_box_blur_size_mismatch_returns_error() {
        // Buffer length that does not match w * h * ch should return an error.
        let result = box_blur(&[0u8; 10], 4, 4, 1, 1);
        assert!(result.is_err(), "expected error on size mismatch");
    }

    // ── median_filter tests ───────────────────────────────────────────────────

    #[test]
    fn test_median_removes_outlier() {
        // 5×5 single-channel image where every pixel is 100 except the centre,
        // which is 255. With radius=1 the 3×3 window around the centre holds
        // eight 100s and one 255, so its median is 100 and the outlier vanishes.
        let w = 5u32;
        let h = 5u32;
        let ch = 1u32;
        let mut input = vec![100u8; (w * h * ch) as usize];
        let cx = 2usize;
        let cy = 2usize;
        input[cy * w as usize + cx] = 255; // outlier

        let output = median_filter(&input, w, h, ch, 1).expect("median_filter should succeed");

        let centre = output[cy * w as usize + cx];
        assert_eq!(
            centre, 100,
            "median should remove the outlier; centre={centre}"
        );
    }

    #[test]
    fn test_median_uniform_image() {
        // Uniform image must be preserved exactly.
        let w = 4u32;
        let h = 4u32;
        let ch = 4u32;
        let input = vec![77u8; (w * h * ch) as usize];
        let output = median_filter(&input, w, h, ch, 2).expect("median_filter uniform");
        assert!(output.iter().all(|&v| v == 77));
    }

    #[test]
    fn test_median_size_mismatch_returns_error() {
        let result = median_filter(&[0u8; 5], 4, 4, 1, 1);
        assert!(result.is_err(), "expected error on size mismatch");
    }

    // ── bilateral_filter tests ────────────────────────────────────────────────

    #[test]
    fn test_bilateral_edge_preserving() {
        // Left half of a 10×10 RGBA image is black (0), right half is white (255).
        // A bilateral filter with a *small* sigma_range gives near-zero weight to
        // neighbours across the 0/255 step, so the edge is preserved: pixels well
        // away from the boundary should stay close to their original value.
        let w = 10u32;
        let h = 10u32;
        let mut input = vec![0u8; (w * h * 4) as usize];
        for row in 0..h as usize {
            for col in 0..w as usize {
                let v: u8 = if col >= 5 { 255 } else { 0 };
                let base = (row * w as usize + col) * 4;
                input[base] = v;
                input[base + 1] = v;
                input[base + 2] = v;
                input[base + 3] = 255;
            }
        }

        // sigma_spatial=2 (small neighbourhood), sigma_range=10 (tight range gate
        // ⇒ edge preserved well).
        let output =
            bilateral_filter(&input, w, h, 4, 2.0, 10.0).expect("bilateral_filter should succeed");

        // Pixels in the interior of the black half should remain close to 0.
        for row in 0..h as usize {
            let col = 1usize; // well inside black half
            let base = (row * w as usize + col) * 4;
            for c in 0..3 {
                assert!(
                    output[base + c] < 64,
                    "row={row} col={col} ch={c}: expected near 0, got {}",
                    output[base + c]
                );
            }
        }

        // Pixels in the interior of the white half should remain close to 255.
        for row in 0..h as usize {
            let col = 8usize; // well inside white half
            let base = (row * w as usize + col) * 4;
            for c in 0..3 {
                assert!(
                    output[base + c] > 191,
                    "row={row} col={col} ch={c}: expected near 255, got {}",
                    output[base + c]
                );
            }
        }
    }

    #[test]
    fn test_bilateral_wrong_channels_returns_error() {
        let result = bilateral_filter(&[0u8; 9], 3, 3, 1, 2.0, 30.0);
        assert!(result.is_err(), "bilateral requires channels == 4");
    }
}