Skip to main content

oximedia_gpu/ops/
filter.rs

1//! Convolution filter operations (blur, sharpen, edge detection)
2
3use crate::{
4    shader::{BindGroupLayoutBuilder, ShaderCompiler, ShaderSource},
5    GpuDevice, GpuError, Result,
6};
7use bytemuck::{Pod, Zeroable};
8use once_cell::sync::OnceCell;
9use wgpu::{BindGroup, BindGroupLayout, ComputePipeline};
10
11use super::utils;
12
13#[repr(C)]
14#[derive(Copy, Clone, Pod, Zeroable)]
15struct FilterParams {
16    width: u32,
17    height: u32,
18    stride: u32,
19    kernel_size: u32,
20    normalize: u32,
21    filter_type: u32,
22    padding: u32,
23    sigma: f32,
24}
25
/// GPU convolution filters (Gaussian blur, unsharp-mask sharpening, Sobel edge
/// detection and user-supplied kernels) operating on packed RGBA images.
///
/// This is a stateless namespace type: all functionality is exposed through
/// associated functions that take the GPU device explicitly.
#[derive(Debug, Clone, Copy, Default)]
pub struct FilterOperation;
28
29impl FilterOperation {
30    /// Apply Gaussian blur
31    ///
32    /// # Arguments
33    ///
34    /// * `device` - GPU device
35    /// * `input` - Input image buffer (packed RGBA format)
36    /// * `output` - Output image buffer (packed RGBA format)
37    /// * `width` - Image width
38    /// * `height` - Image height
39    /// * `sigma` - Blur radius (standard deviation)
40    ///
41    /// # Errors
42    ///
43    /// Returns an error if buffer sizes are invalid or if the GPU operation fails.
44    #[allow(clippy::too_many_arguments)]
45    pub fn gaussian_blur(
46        device: &GpuDevice,
47        input: &[u8],
48        output: &mut [u8],
49        width: u32,
50        height: u32,
51        sigma: f32,
52    ) -> Result<()> {
53        utils::validate_dimensions(width, height)?;
54        utils::validate_buffer_size(input, width, height, 4)?;
55        utils::validate_buffer_size(output, width, height, 4)?;
56
57        let kernel_size = Self::calculate_kernel_size(sigma);
58        let pipeline = Self::get_gaussian_pipeline(device)?;
59        let layout = Self::get_bind_group_layout(device)?;
60
61        Self::execute_filter(
62            device,
63            pipeline,
64            layout,
65            input,
66            output,
67            width,
68            height,
69            kernel_size,
70            1, // Gaussian filter type
71            sigma,
72        )
73    }
74
75    /// Apply sharpening filter (unsharp mask)
76    ///
77    /// # Arguments
78    ///
79    /// * `device` - GPU device
80    /// * `input` - Input image buffer (packed RGBA format)
81    /// * `output` - Output image buffer (packed RGBA format)
82    /// * `width` - Image width
83    /// * `height` - Image height
84    /// * `amount` - Sharpening strength
85    ///
86    /// # Errors
87    ///
88    /// Returns an error if buffer sizes are invalid or if the GPU operation fails.
89    #[allow(clippy::too_many_arguments)]
90    pub fn sharpen(
91        device: &GpuDevice,
92        input: &[u8],
93        output: &mut [u8],
94        width: u32,
95        height: u32,
96        amount: f32,
97    ) -> Result<()> {
98        utils::validate_dimensions(width, height)?;
99        utils::validate_buffer_size(input, width, height, 4)?;
100        utils::validate_buffer_size(output, width, height, 4)?;
101
102        let pipeline = Self::get_sharpen_pipeline(device)?;
103        let layout = Self::get_bind_group_layout(device)?;
104
105        Self::execute_filter(
106            device, pipeline, layout, input, output, width, height,
107            5, // Kernel size for sharpening
108            2, // Sharpen filter type
109            amount,
110        )
111    }
112
113    /// Detect edges using Sobel operator
114    ///
115    /// # Arguments
116    ///
117    /// * `device` - GPU device
118    /// * `input` - Input image buffer (packed RGBA format)
119    /// * `output` - Output image buffer (packed RGBA format)
120    /// * `width` - Image width
121    /// * `height` - Image height
122    ///
123    /// # Errors
124    ///
125    /// Returns an error if buffer sizes are invalid or if the GPU operation fails.
126    pub fn edge_detect(
127        device: &GpuDevice,
128        input: &[u8],
129        output: &mut [u8],
130        width: u32,
131        height: u32,
132    ) -> Result<()> {
133        utils::validate_dimensions(width, height)?;
134        utils::validate_buffer_size(input, width, height, 4)?;
135        utils::validate_buffer_size(output, width, height, 4)?;
136
137        let pipeline = Self::get_edge_detect_pipeline(device)?;
138        let layout = Self::get_bind_group_layout(device)?;
139
140        Self::execute_filter(
141            device, pipeline, layout, input, output, width, height, 3, // 3x3 Sobel kernel
142            3, // Edge detect filter type
143            0.0,
144        )
145    }
146
147    /// Apply custom convolution kernel
148    ///
149    /// # Arguments
150    ///
151    /// * `device` - GPU device
152    /// * `input` - Input image buffer (packed RGBA format)
153    /// * `output` - Output image buffer (packed RGBA format)
154    /// * `width` - Image width
155    /// * `height` - Image height
156    /// * `kernel` - Convolution kernel (must be square and odd-sized)
157    /// * `normalize` - Whether to normalize the kernel
158    ///
159    /// # Errors
160    ///
161    /// Returns an error if buffer sizes are invalid or if the GPU operation fails.
162    #[allow(clippy::too_many_arguments)]
163    pub fn convolve(
164        device: &GpuDevice,
165        input: &[u8],
166        output: &mut [u8],
167        width: u32,
168        height: u32,
169        kernel: &[f32],
170        normalize: bool,
171    ) -> Result<()> {
172        utils::validate_dimensions(width, height)?;
173        utils::validate_buffer_size(input, width, height, 4)?;
174        utils::validate_buffer_size(output, width, height, 4)?;
175
176        let kernel_size = (kernel.len() as f32).sqrt() as u32;
177        if kernel_size * kernel_size != kernel.len() as u32 {
178            return Err(GpuError::Internal("Kernel must be square".to_string()));
179        }
180        if kernel_size % 2 == 0 {
181            return Err(GpuError::Internal("Kernel size must be odd".to_string()));
182        }
183
184        let pipeline = Self::get_convolve_pipeline(device)?;
185        let layout = Self::get_bind_group_layout_with_kernel(device)?;
186
187        Self::execute_convolve(
188            device,
189            pipeline,
190            layout,
191            input,
192            output,
193            width,
194            height,
195            kernel,
196            kernel_size,
197            normalize,
198        )
199    }
200
201    #[allow(clippy::too_many_arguments)]
202    fn execute_filter(
203        device: &GpuDevice,
204        pipeline: &ComputePipeline,
205        layout: &BindGroupLayout,
206        input: &[u8],
207        output: &mut [u8],
208        width: u32,
209        height: u32,
210        kernel_size: u32,
211        filter_type: u32,
212        sigma: f32,
213    ) -> Result<()> {
214        // Create buffers
215        let input_buffer = utils::create_storage_buffer(device, input.len() as u64)?;
216        let output_buffer = utils::create_storage_buffer(device, output.len() as u64)?;
217
218        // Upload input data
219        device.queue().write_buffer(input_buffer.buffer(), 0, input);
220
221        // Create uniform buffer for parameters
222        let params = FilterParams {
223            width,
224            height,
225            stride: width,
226            kernel_size,
227            normalize: 1,
228            filter_type,
229            padding: 0,
230            sigma,
231        };
232        let params_bytes = bytemuck::bytes_of(&params);
233        let params_buffer = utils::create_uniform_buffer(device, params_bytes)?;
234
235        // Create bind group
236        let compiler = ShaderCompiler::new(device);
237        let bind_group = compiler.create_bind_group(
238            "Filter Bind Group",
239            layout,
240            &[
241                wgpu::BindGroupEntry {
242                    binding: 0,
243                    resource: input_buffer.buffer().as_entire_binding(),
244                },
245                wgpu::BindGroupEntry {
246                    binding: 1,
247                    resource: output_buffer.buffer().as_entire_binding(),
248                },
249                wgpu::BindGroupEntry {
250                    binding: 2,
251                    resource: params_buffer.buffer().as_entire_binding(),
252                },
253            ],
254        );
255
256        // Execute compute pass
257        Self::dispatch_compute(device, pipeline, &bind_group, width, height)?;
258
259        // Read back results
260        let readback_buffer = utils::create_readback_buffer(device, output.len() as u64)?;
261        let mut encoder = device
262            .device()
263            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
264                label: Some("Filter Copy Encoder"),
265            });
266
267        output_buffer.copy_to(&mut encoder, &readback_buffer, 0, 0, output.len() as u64)?;
268
269        device.queue().submit(Some(encoder.finish()));
270        device.wait();
271
272        let result = readback_buffer.read(device, 0, output.len() as u64)?;
273        output.copy_from_slice(&result);
274
275        Ok(())
276    }
277
278    #[allow(clippy::too_many_arguments)]
279    fn execute_convolve(
280        device: &GpuDevice,
281        pipeline: &ComputePipeline,
282        layout: &BindGroupLayout,
283        input: &[u8],
284        output: &mut [u8],
285        width: u32,
286        height: u32,
287        kernel: &[f32],
288        kernel_size: u32,
289        normalize: bool,
290    ) -> Result<()> {
291        // Create buffers
292        let input_buffer = utils::create_storage_buffer(device, input.len() as u64)?;
293        let output_buffer = utils::create_storage_buffer(device, output.len() as u64)?;
294
295        // Upload input data
296        device.queue().write_buffer(input_buffer.buffer(), 0, input);
297
298        // Create kernel buffer
299        let kernel_bytes = bytemuck::cast_slice(kernel);
300        let kernel_buffer = utils::create_storage_buffer(device, kernel_bytes.len() as u64)?;
301        device
302            .queue()
303            .write_buffer(kernel_buffer.buffer(), 0, kernel_bytes);
304
305        // Create uniform buffer for parameters
306        let params = FilterParams {
307            width,
308            height,
309            stride: width,
310            kernel_size,
311            normalize: u32::from(normalize),
312            filter_type: 0, // Custom kernel
313            padding: 0,
314            sigma: 0.0,
315        };
316        let params_bytes = bytemuck::bytes_of(&params);
317        let params_buffer = utils::create_uniform_buffer(device, params_bytes)?;
318
319        // Create bind group
320        let compiler = ShaderCompiler::new(device);
321        let bind_group = compiler.create_bind_group(
322            "Filter Bind Group",
323            layout,
324            &[
325                wgpu::BindGroupEntry {
326                    binding: 0,
327                    resource: input_buffer.buffer().as_entire_binding(),
328                },
329                wgpu::BindGroupEntry {
330                    binding: 1,
331                    resource: output_buffer.buffer().as_entire_binding(),
332                },
333                wgpu::BindGroupEntry {
334                    binding: 2,
335                    resource: params_buffer.buffer().as_entire_binding(),
336                },
337                wgpu::BindGroupEntry {
338                    binding: 3,
339                    resource: kernel_buffer.buffer().as_entire_binding(),
340                },
341            ],
342        );
343
344        // Execute compute pass
345        Self::dispatch_compute(device, pipeline, &bind_group, width, height)?;
346
347        // Read back results
348        let readback_buffer = utils::create_readback_buffer(device, output.len() as u64)?;
349        let mut encoder = device
350            .device()
351            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
352                label: Some("Filter Copy Encoder"),
353            });
354
355        output_buffer.copy_to(&mut encoder, &readback_buffer, 0, 0, output.len() as u64)?;
356
357        device.queue().submit(Some(encoder.finish()));
358        device.wait();
359
360        let result = readback_buffer.read(device, 0, output.len() as u64)?;
361        output.copy_from_slice(&result);
362
363        Ok(())
364    }
365
366    fn dispatch_compute(
367        device: &GpuDevice,
368        pipeline: &ComputePipeline,
369        bind_group: &BindGroup,
370        width: u32,
371        height: u32,
372    ) -> Result<()> {
373        let mut encoder = device
374            .device()
375            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
376                label: Some("Filter Compute Encoder"),
377            });
378
379        {
380            let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
381                label: Some("Filter Compute Pass"),
382                timestamp_writes: None,
383            });
384
385            compute_pass.set_pipeline(pipeline);
386            compute_pass.set_bind_group(0, bind_group, &[]);
387
388            let (dispatch_x, dispatch_y) = utils::calculate_dispatch_size(width, height, (16, 16));
389            compute_pass.dispatch_workgroups(dispatch_x, dispatch_y, 1);
390        }
391
392        device.queue().submit(Some(encoder.finish()));
393        Ok(())
394    }
395
396    fn calculate_kernel_size(sigma: f32) -> u32 {
397        // Use 3-sigma rule: kernel size = 2 * ceil(3 * sigma) + 1
398        let radius = (3.0 * sigma).ceil() as u32;
399        2 * radius + 1
400    }
401
402    fn get_bind_group_layout(device: &GpuDevice) -> Result<&'static BindGroupLayout> {
403        static LAYOUT: OnceCell<BindGroupLayout> = OnceCell::new();
404
405        Ok(LAYOUT.get_or_init(|| {
406            let compiler = ShaderCompiler::new(device);
407            let entries = BindGroupLayoutBuilder::new()
408                .add_storage_buffer_read_only(0) // input
409                .add_storage_buffer(1) // output
410                .add_uniform_buffer(2) // params
411                .build();
412
413            compiler.create_bind_group_layout("Filter Bind Group Layout", &entries)
414        }))
415    }
416
417    fn get_bind_group_layout_with_kernel(device: &GpuDevice) -> Result<&'static BindGroupLayout> {
418        static LAYOUT: OnceCell<BindGroupLayout> = OnceCell::new();
419
420        Ok(LAYOUT.get_or_init(|| {
421            let compiler = ShaderCompiler::new(device);
422            let entries = BindGroupLayoutBuilder::new()
423                .add_storage_buffer_read_only(0) // input
424                .add_storage_buffer(1) // output
425                .add_uniform_buffer(2) // params
426                .add_storage_buffer_read_only(3) // kernel
427                .build();
428
429            compiler.create_bind_group_layout("Filter Bind Group Layout (with kernel)", &entries)
430        }))
431    }
432
433    fn init_pipeline(
434        device: &GpuDevice,
435        name: &str,
436        entry_point: &str,
437        layout_fn: fn(&GpuDevice) -> Result<&'static BindGroupLayout>,
438    ) -> std::result::Result<ComputePipeline, String> {
439        let compiler = ShaderCompiler::new(device);
440        let shader = compiler
441            .compile(
442                "Filter Shader",
443                ShaderSource::Embedded(crate::shader::embedded::FILTER_SHADER),
444            )
445            .map_err(|e| format!("Failed to compile filter shader: {e}"))?;
446
447        let layout =
448            layout_fn(device).map_err(|e| format!("Failed to create bind group layout: {e}"))?;
449
450        compiler
451            .create_pipeline(name, &shader, entry_point, layout)
452            .map_err(|e| format!("Failed to create pipeline: {e}"))
453    }
454
455    fn get_gaussian_pipeline(device: &GpuDevice) -> Result<&'static ComputePipeline> {
456        static PIPELINE: OnceCell<std::result::Result<ComputePipeline, String>> = OnceCell::new();
457
458        PIPELINE
459            .get_or_init(|| {
460                FilterOperation::init_pipeline(
461                    device,
462                    "Gaussian Blur Pipeline",
463                    "convolve_main",
464                    Self::get_bind_group_layout,
465                )
466            })
467            .as_ref()
468            .map_err(|e| crate::GpuError::PipelineCreation(e.clone()))
469    }
470
471    fn get_sharpen_pipeline(device: &GpuDevice) -> Result<&'static ComputePipeline> {
472        static PIPELINE: OnceCell<std::result::Result<ComputePipeline, String>> = OnceCell::new();
473
474        PIPELINE
475            .get_or_init(|| {
476                FilterOperation::init_pipeline(
477                    device,
478                    "Sharpen Pipeline",
479                    "unsharp_mask",
480                    Self::get_bind_group_layout,
481                )
482            })
483            .as_ref()
484            .map_err(|e| crate::GpuError::PipelineCreation(e.clone()))
485    }
486
487    fn get_edge_detect_pipeline(device: &GpuDevice) -> Result<&'static ComputePipeline> {
488        static PIPELINE: OnceCell<std::result::Result<ComputePipeline, String>> = OnceCell::new();
489
490        PIPELINE
491            .get_or_init(|| {
492                FilterOperation::init_pipeline(
493                    device,
494                    "Edge Detect Pipeline",
495                    "edge_detect",
496                    Self::get_bind_group_layout,
497                )
498            })
499            .as_ref()
500            .map_err(|e| crate::GpuError::PipelineCreation(e.clone()))
501    }
502
503    fn get_convolve_pipeline(device: &GpuDevice) -> Result<&'static ComputePipeline> {
504        static PIPELINE: OnceCell<std::result::Result<ComputePipeline, String>> = OnceCell::new();
505
506        PIPELINE
507            .get_or_init(|| {
508                FilterOperation::init_pipeline(
509                    device,
510                    "Convolve Pipeline",
511                    "convolve_main",
512                    Self::get_bind_group_layout_with_kernel,
513                )
514            })
515            .as_ref()
516            .map_err(|e| crate::GpuError::PipelineCreation(e.clone()))
517    }
518}
519
520// ---------------------------------------------------------------------------
521// Separable CPU Gaussian blur (Task 9)
522// ---------------------------------------------------------------------------
523
/// Build a normalised 1-D Gaussian kernel of radius `ceil(3σ)`.
///
/// The returned `Vec<f32>` has length `2*radius+1`, is symmetric about its
/// centre and sums to 1.0 (up to rounding).
///
/// If `sigma` is ≤ 0 or not finite (NaN or ±∞), a single-element identity
/// kernel `[1.0]` is returned. Without this guard, NaN would poison every
/// weight and +∞ would request an unbounded radius.
#[must_use]
pub fn gaussian_kernel_1d(sigma: f32) -> Vec<f32> {
    if !sigma.is_finite() || sigma <= 0.0 {
        return vec![1.0_f32];
    }
    // Extremely large finite sigmas still produce a huge kernel; callers are
    // expected to pass pixel-scale values.
    let radius = (3.0 * sigma).ceil() as usize;
    let len = 2 * radius + 1;
    let two_sigma_sq = 2.0 * sigma * sigma;

    let mut kernel: Vec<f32> = (0..len)
        .map(|i| {
            let x = i as f32 - radius as f32;
            (-x * x / two_sigma_sq).exp()
        })
        .collect();

    let sum: f32 = kernel.iter().sum();
    for weight in &mut kernel {
        *weight /= sum;
    }
    kernel
}
549
550/// CPU-side separable Gaussian blur (two-pass: horizontal then vertical).
551///
552/// # Arguments
553///
554/// * `input`  — source RGBA bytes (`width × height × 4`)
555/// * `output` — destination RGBA bytes (same size as input)
556/// * `width`, `height` — image dimensions
557/// * `sigma`  — Gaussian standard deviation in pixels (> 0)
558///
559/// # Errors
560///
561/// Returns an error if buffer sizes do not match `width × height × 4`.
562pub fn gaussian_blur_separable(
563    input: &[u8],
564    output: &mut [u8],
565    width: u32,
566    height: u32,
567    sigma: f32,
568) -> crate::Result<()> {
569    utils::validate_dimensions(width, height)?;
570    utils::validate_buffer_size(input, width, height, 4)?;
571    utils::validate_buffer_size(output, width, height, 4)?;
572
573    let w = width as usize;
574    let h = height as usize;
575    let kernel = gaussian_kernel_1d(sigma);
576    let radius = kernel.len() / 2;
577
578    // Horizontal pass — accumulate into f32 buffer to avoid clamping artefacts
579    let mut h_pass = vec![0.0_f32; w * h * 4];
580    for row in 0..h {
581        for col in 0..w {
582            let mut acc = [0.0_f32; 4];
583            let mut wsum = 0.0_f32;
584            for (ki, &kw) in kernel.iter().enumerate() {
585                let sc = col as isize + ki as isize - radius as isize;
586                if sc < 0 || sc >= w as isize {
587                    continue;
588                }
589                let src = (row * w + sc as usize) * 4;
590                for c in 0..4 {
591                    acc[c] += kw * input[src + c] as f32;
592                }
593                wsum += kw;
594            }
595            let dst = (row * w + col) * 4;
596            let inv = if wsum > 0.0 { 1.0 / wsum } else { 1.0 };
597            for c in 0..4 {
598                h_pass[dst + c] = acc[c] * inv;
599            }
600        }
601    }
602
603    // Vertical pass — read from h_pass, write to output
604    for row in 0..h {
605        for col in 0..w {
606            let mut acc = [0.0_f32; 4];
607            let mut wsum = 0.0_f32;
608            for (ki, &kw) in kernel.iter().enumerate() {
609                let sr = row as isize + ki as isize - radius as isize;
610                if sr < 0 || sr >= h as isize {
611                    continue;
612                }
613                let src = (sr as usize * w + col) * 4;
614                for c in 0..4 {
615                    acc[c] += kw * h_pass[src + c];
616                }
617                wsum += kw;
618            }
619            let dst = (row * w + col) * 4;
620            let inv = if wsum > 0.0 { 1.0 / wsum } else { 1.0 };
621            for c in 0..4 {
622                output[dst + c] = (acc[c] * inv).round().clamp(0.0, 255.0) as u8;
623            }
624        }
625    }
626
627    Ok(())
628}
629
630use rayon::prelude::*;
631
632/// CPU-side separable Gaussian blur with Rayon parallel row/column processing.
633///
634/// This is an optimised variant of [`gaussian_blur_separable`] that
635/// parallelises both the horizontal and vertical passes using Rayon.
636///
637/// # Arguments
638///
639/// * `input`  — source RGBA bytes (`width × height × 4`)
640/// * `output` — destination RGBA bytes (same size as input)
641/// * `width`, `height` — image dimensions
642/// * `sigma`  — Gaussian standard deviation in pixels (> 0)
643///
644/// # Errors
645///
646/// Returns an error if buffer sizes do not match `width × height × 4`.
647pub fn gaussian_blur_separable_parallel(
648    input: &[u8],
649    output: &mut [u8],
650    width: u32,
651    height: u32,
652    sigma: f32,
653) -> crate::Result<()> {
654    utils::validate_dimensions(width, height)?;
655    utils::validate_buffer_size(input, width, height, 4)?;
656    utils::validate_buffer_size(output, width, height, 4)?;
657
658    let w = width as usize;
659    let h = height as usize;
660    let kernel = gaussian_kernel_1d(sigma);
661    let radius = kernel.len() / 2;
662
663    // Horizontal pass (parallel over rows)
664    let mut h_pass = vec![0.0_f32; w * h * 4];
665    h_pass
666        .par_chunks_exact_mut(w * 4)
667        .enumerate()
668        .for_each(|(row, row_out)| {
669            for col in 0..w {
670                let mut acc = [0.0_f32; 4];
671                let mut wsum = 0.0_f32;
672                for (ki, &kw) in kernel.iter().enumerate() {
673                    let sc = col as isize + ki as isize - radius as isize;
674                    if sc < 0 || sc >= w as isize {
675                        continue;
676                    }
677                    let src = (row * w + sc as usize) * 4;
678                    for c in 0..4 {
679                        acc[c] += kw * input[src + c] as f32;
680                    }
681                    wsum += kw;
682                }
683                let inv = if wsum > 0.0 { 1.0 / wsum } else { 1.0 };
684                let dst = col * 4;
685                for c in 0..4 {
686                    row_out[dst + c] = acc[c] * inv;
687                }
688            }
689        });
690
691    // Vertical pass (parallel over columns)
692    output
693        .par_chunks_exact_mut(4)
694        .enumerate()
695        .for_each(|(px_idx, px_out)| {
696            let row = px_idx / w;
697            let col = px_idx % w;
698            let mut acc = [0.0_f32; 4];
699            let mut wsum = 0.0_f32;
700            for (ki, &kw) in kernel.iter().enumerate() {
701                let sr = row as isize + ki as isize - radius as isize;
702                if sr < 0 || sr >= h as isize {
703                    continue;
704                }
705                let src = (sr as usize * w + col) * 4;
706                for c in 0..4 {
707                    acc[c] += kw * h_pass[src + c];
708                }
709                wsum += kw;
710            }
711            let inv = if wsum > 0.0 { 1.0 / wsum } else { 1.0 };
712            for c in 0..4 {
713                px_out[c] = (acc[c] * inv).round().clamp(0.0, 255.0) as u8;
714            }
715        });
716
717    Ok(())
718}
719
/// Largest absolute per-byte difference between two RGBA `u8` buffers.
///
/// Intended for checking that the serial and parallel separable blurs agree.
/// Bytes are compared pairwise over the buffers' common prefix. If the lengths
/// differ, the trailing bytes of the longer buffer are ignored. Returns 0 when
/// either buffer is empty.
#[must_use]
pub fn max_channel_diff(a: &[u8], b: &[u8]) -> u32 {
    a.iter().zip(b).fold(0_u32, |worst, (&x, &y)| {
        let delta = (i32::from(x) - i32::from(y)).unsigned_abs();
        worst.max(delta)
    })
}
732
733#[cfg(test)]
734mod tests {
735    use super::*;
736
737    #[test]
738    fn test_kernel_sums_to_one() {
739        let k = gaussian_kernel_1d(1.0);
740        let sum: f32 = k.iter().sum();
741        assert!((sum - 1.0).abs() < 1e-5, "kernel sum = {sum}");
742    }
743
744    #[test]
745    fn test_kernel_is_symmetric() {
746        let k = gaussian_kernel_1d(2.0);
747        let n = k.len();
748        for i in 0..n / 2 {
749            assert!(
750                (k[i] - k[n - 1 - i]).abs() < 1e-6,
751                "asymmetric at index {i}: {} vs {}",
752                k[i],
753                k[n - 1 - i]
754            );
755        }
756    }
757
758    #[test]
759    fn test_kernel_center_is_largest() {
760        let k = gaussian_kernel_1d(1.5);
761        let center = k[k.len() / 2];
762        for &v in &k {
763            assert!(center >= v, "center {center} not >= {v}");
764        }
765    }
766
767    #[test]
768    fn test_kernel_zero_sigma_returns_identity() {
769        let k = gaussian_kernel_1d(0.0);
770        assert_eq!(k.len(), 1);
771        assert!((k[0] - 1.0).abs() < 1e-6);
772    }
773
774    #[test]
775    fn test_kernel_negative_sigma_returns_identity() {
776        let k = gaussian_kernel_1d(-1.0);
777        assert_eq!(k.len(), 1);
778        assert!((k[0] - 1.0).abs() < 1e-6);
779    }
780
781    #[test]
782    fn test_blur_uniform_image_unchanged() {
783        let w = 8u32;
784        let h = 8u32;
785        let input: Vec<u8> = (0..(w * h * 4) as usize)
786            .map(|i| if i % 4 == 3 { 255 } else { 128 })
787            .collect();
788        let mut output = vec![0u8; (w * h * 4) as usize];
789        gaussian_blur_separable(&input, &mut output, w, h, 1.5).expect("blur should succeed");
790        for (i, (&inp, &out)) in input.iter().zip(output.iter()).enumerate() {
791            assert!(
792                (inp as i32 - out as i32).unsigned_abs() <= 1,
793                "pixel {i}: input={inp} output={out}"
794            );
795        }
796    }
797
798    #[test]
799    fn test_blur_reduces_contrast() {
800        let w = 4u32;
801        let h = 4u32;
802        let mut input = vec![0u8; (w * h * 4) as usize];
803        for row in 0..h as usize {
804            for col in 0..w as usize {
805                let v = if (row + col) % 2 == 0 { 255u8 } else { 0u8 };
806                let base = (row * w as usize + col) * 4;
807                input[base] = v;
808                input[base + 1] = v;
809                input[base + 2] = v;
810                input[base + 3] = 255;
811            }
812        }
813        let mut output = vec![0u8; (w * h * 4) as usize];
814        gaussian_blur_separable(&input, &mut output, w, h, 1.0).expect("blur should succeed");
815        let max_rgb = output
816            .chunks(4)
817            .flat_map(|px| &px[..3])
818            .copied()
819            .max()
820            .unwrap_or(0);
821        assert!(
822            max_rgb < 255,
823            "max_rgb after blur = {max_rgb}; expected < 255"
824        );
825    }
826
827    #[test]
828    fn test_blur_size_mismatch_returns_error() {
829        let w = 4u32;
830        let h = 4u32;
831        let input = vec![0u8; (w * h * 4) as usize];
832        let mut output = vec![0u8; 10];
833        let result = gaussian_blur_separable(&input, &mut output, w, h, 1.0);
834        assert!(result.is_err());
835    }
836
837    #[test]
838    fn test_blur_single_pixel_passthrough() {
839        let input = vec![100u8, 150u8, 200u8, 255u8];
840        let mut output = vec![0u8; 4];
841        gaussian_blur_separable(&input, &mut output, 1, 1, 1.0).expect("blur should succeed");
842        assert_eq!(output[0], 100);
843        assert_eq!(output[1], 150);
844        assert_eq!(output[2], 200);
845        assert_eq!(output[3], 255);
846    }
847
848    // ── Parallel blur tests ───────────────────────────────────────────────────
849
850    #[test]
851    fn test_parallel_blur_matches_serial_uniform_image() {
852        let w = 16u32;
853        let h = 16u32;
854        let input: Vec<u8> = vec![128u8; (w * h * 4) as usize];
855        let mut serial = vec![0u8; (w * h * 4) as usize];
856        let mut parallel = vec![0u8; (w * h * 4) as usize];
857        gaussian_blur_separable(&input, &mut serial, w, h, 1.5).expect("serial blur");
858        gaussian_blur_separable_parallel(&input, &mut parallel, w, h, 1.5).expect("parallel blur");
859        assert_eq!(
860            max_channel_diff(&serial, &parallel),
861            0,
862            "serial and parallel must agree on uniform image"
863        );
864    }
865
866    #[test]
867    fn test_parallel_blur_matches_serial_random_image() {
868        let w = 8u32;
869        let h = 8u32;
870        let input: Vec<u8> = (0..(w * h * 4) as usize)
871            .map(|i| ((i * 37 + 13) % 256) as u8)
872            .collect();
873        let mut serial = vec![0u8; (w * h * 4) as usize];
874        let mut parallel = vec![0u8; (w * h * 4) as usize];
875        gaussian_blur_separable(&input, &mut serial, w, h, 1.0).expect("serial blur");
876        gaussian_blur_separable_parallel(&input, &mut parallel, w, h, 1.0).expect("parallel blur");
877        let max_diff = max_channel_diff(&serial, &parallel);
878        assert_eq!(max_diff, 0, "serial and parallel outputs must be identical");
879    }
880
881    #[test]
882    fn test_parallel_blur_single_pixel_passthrough() {
883        let input = vec![77u8, 88, 99, 255];
884        let mut output = vec![0u8; 4];
885        gaussian_blur_separable_parallel(&input, &mut output, 1, 1, 2.0)
886            .expect("single pixel parallel blur");
887        assert_eq!(output[0], 77);
888        assert_eq!(output[1], 88);
889        assert_eq!(output[2], 99);
890        assert_eq!(output[3], 255);
891    }
892
893    #[test]
894    fn test_parallel_blur_size_mismatch_returns_error() {
895        let input = vec![0u8; 4 * 4 * 4];
896        let mut output = vec![0u8; 5]; // wrong
897        let res = gaussian_blur_separable_parallel(&input, &mut output, 4, 4, 1.0);
898        assert!(res.is_err());
899    }
900
901    #[test]
902    fn test_parallel_blur_reduces_contrast() {
903        let w = 8u32;
904        let h = 8u32;
905        let mut input = vec![0u8; (w * h * 4) as usize];
906        for row in 0..h as usize {
907            for col in 0..w as usize {
908                let v = if (row + col) % 2 == 0 { 255u8 } else { 0u8 };
909                let base = (row * w as usize + col) * 4;
910                input[base] = v;
911                input[base + 1] = v;
912                input[base + 2] = v;
913                input[base + 3] = 255;
914            }
915        }
916        let mut output = vec![0u8; (w * h * 4) as usize];
917        gaussian_blur_separable_parallel(&input, &mut output, w, h, 1.5)
918            .expect("parallel contrast blur");
919        let max_rgb = output
920            .chunks(4)
921            .flat_map(|px| &px[..3])
922            .copied()
923            .max()
924            .unwrap_or(0);
925        assert!(
926            max_rgb < 255,
927            "parallel blur should reduce max brightness; got {max_rgb}"
928        );
929    }
930
931    #[test]
932    fn test_parallel_blur_large_sigma_heavy_smoothing() {
933        let w = 16u32;
934        let h = 16u32;
935        // Alternating black/white checkerboard
936        let input: Vec<u8> = (0..(w * h) as usize)
937            .flat_map(|i| {
938                let row = i / w as usize;
939                let col = i % w as usize;
940                let v = if (row + col) % 2 == 0 { 255u8 } else { 0u8 };
941                [v, v, v, 255u8]
942            })
943            .collect();
944        let mut out_small = vec![0u8; (w * h * 4) as usize];
945        let mut out_large = vec![0u8; (w * h * 4) as usize];
946        gaussian_blur_separable_parallel(&input, &mut out_small, w, h, 0.5).expect("small sigma");
947        gaussian_blur_separable_parallel(&input, &mut out_large, w, h, 3.0).expect("large sigma");
948
949        let range_small: u32 = out_small
950            .chunks(4)
951            .map(|px| px[0] as u32)
952            .max()
953            .unwrap_or(0)
954            - out_small
955                .chunks(4)
956                .map(|px| px[0] as u32)
957                .min()
958                .unwrap_or(0);
959        let range_large: u32 = out_large
960            .chunks(4)
961            .map(|px| px[0] as u32)
962            .max()
963            .unwrap_or(0)
964            - out_large
965                .chunks(4)
966                .map(|px| px[0] as u32)
967                .min()
968                .unwrap_or(0);
969        assert!(
970            range_large <= range_small,
971            "larger sigma should produce smaller contrast range; small={range_small}, large={range_large}"
972        );
973    }
974
975    #[test]
976    fn test_parallel_blur_wide_image() {
977        let w = 32u32;
978        let h = 4u32;
979        let input: Vec<u8> = (0..(w * h * 4) as usize).map(|i| (i % 256) as u8).collect();
980        let mut output = vec![0u8; (w * h * 4) as usize];
981        gaussian_blur_separable_parallel(&input, &mut output, w, h, 1.0)
982            .expect("wide image parallel blur");
983        assert_eq!(output.len(), (w * h * 4) as usize);
984    }
985
986    #[test]
987    fn test_parallel_blur_tall_image() {
988        let w = 4u32;
989        let h = 32u32;
990        let input: Vec<u8> = (0..(w * h * 4) as usize).map(|i| (i % 256) as u8).collect();
991        let mut output = vec![0u8; (w * h * 4) as usize];
992        gaussian_blur_separable_parallel(&input, &mut output, w, h, 1.0)
993            .expect("tall image parallel blur");
994        assert_eq!(output.len(), (w * h * 4) as usize);
995    }
996
997    #[test]
998    fn test_max_channel_diff_identical() {
999        let a = vec![128u8; 16];
1000        let diff = max_channel_diff(&a, &a);
1001        assert_eq!(diff, 0);
1002    }
1003
1004    #[test]
1005    fn test_max_channel_diff_known_values() {
1006        let a = vec![100u8, 200, 50, 255];
1007        let b = vec![90u8, 210, 50, 255];
1008        let diff = max_channel_diff(&a, &b);
1009        assert_eq!(diff, 10);
1010    }
1011}