1use crate::{
4 shader::{BindGroupLayoutBuilder, ShaderCompiler, ShaderSource},
5 GpuDevice, GpuError, Result,
6};
7use bytemuck::{Pod, Zeroable};
8use once_cell::sync::OnceCell;
9use wgpu::{BindGroup, BindGroupLayout, ComputePipeline};
10
11use super::utils;
12
/// Parameter block handed to the filter shader.
///
/// The struct is converted to raw bytes with `bytemuck::bytes_of` and placed in
/// the buffer bound at binding 2, so the field order and `#[repr(C)]` layout are
/// part of the contract with the shader (eight 4-byte fields).
#[repr(C)]
#[derive(Copy, Clone, Pod, Zeroable)]
struct FilterParams {
    /// Image width, forwarded unchanged from the public entry points.
    width: u32,
    /// Image height, forwarded unchanged from the public entry points.
    height: u32,
    /// Row stride; every construction site in this file sets it to `width`.
    stride: u32,
    /// Side length of the filter kernel.
    kernel_size: u32,
    /// Normalisation flag: `1` for the fixed filters, `u32::from(normalize)`
    /// for custom convolution.
    normalize: u32,
    /// Filter selector as set in this file: `0` custom convolution,
    /// `1` Gaussian blur, `2` sharpen, `3` edge detect.
    filter_type: u32,
    /// Always written as `0` in this file.
    // NOTE(review): presumably reserved/alignment space for the shader — confirm.
    padding: u32,
    /// Gaussian sigma; `sharpen` passes its `amount` in this slot, and edge
    /// detection / custom convolution pass `0.0`.
    sigma: f32,
}
25
/// Stateless namespace for the GPU image-filter entry points.
///
/// All functionality is exposed through associated functions; the type carries
/// no data, so it is trivially printable, copyable and default-constructible.
#[derive(Debug, Clone, Copy, Default)]
pub struct FilterOperation;
28
29impl FilterOperation {
30 #[allow(clippy::too_many_arguments)]
45 pub fn gaussian_blur(
46 device: &GpuDevice,
47 input: &[u8],
48 output: &mut [u8],
49 width: u32,
50 height: u32,
51 sigma: f32,
52 ) -> Result<()> {
53 utils::validate_dimensions(width, height)?;
54 utils::validate_buffer_size(input, width, height, 4)?;
55 utils::validate_buffer_size(output, width, height, 4)?;
56
57 let kernel_size = Self::calculate_kernel_size(sigma);
58 let pipeline = Self::get_gaussian_pipeline(device)?;
59 let layout = Self::get_bind_group_layout(device)?;
60
61 Self::execute_filter(
62 device,
63 pipeline,
64 layout,
65 input,
66 output,
67 width,
68 height,
69 kernel_size,
70 1, sigma,
72 )
73 }
74
75 #[allow(clippy::too_many_arguments)]
90 pub fn sharpen(
91 device: &GpuDevice,
92 input: &[u8],
93 output: &mut [u8],
94 width: u32,
95 height: u32,
96 amount: f32,
97 ) -> Result<()> {
98 utils::validate_dimensions(width, height)?;
99 utils::validate_buffer_size(input, width, height, 4)?;
100 utils::validate_buffer_size(output, width, height, 4)?;
101
102 let pipeline = Self::get_sharpen_pipeline(device)?;
103 let layout = Self::get_bind_group_layout(device)?;
104
105 Self::execute_filter(
106 device, pipeline, layout, input, output, width, height,
107 5, 2, amount,
110 )
111 }
112
113 pub fn edge_detect(
127 device: &GpuDevice,
128 input: &[u8],
129 output: &mut [u8],
130 width: u32,
131 height: u32,
132 ) -> Result<()> {
133 utils::validate_dimensions(width, height)?;
134 utils::validate_buffer_size(input, width, height, 4)?;
135 utils::validate_buffer_size(output, width, height, 4)?;
136
137 let pipeline = Self::get_edge_detect_pipeline(device)?;
138 let layout = Self::get_bind_group_layout(device)?;
139
140 Self::execute_filter(
141 device, pipeline, layout, input, output, width, height, 3, 3, 0.0,
144 )
145 }
146
147 #[allow(clippy::too_many_arguments)]
163 pub fn convolve(
164 device: &GpuDevice,
165 input: &[u8],
166 output: &mut [u8],
167 width: u32,
168 height: u32,
169 kernel: &[f32],
170 normalize: bool,
171 ) -> Result<()> {
172 utils::validate_dimensions(width, height)?;
173 utils::validate_buffer_size(input, width, height, 4)?;
174 utils::validate_buffer_size(output, width, height, 4)?;
175
176 let kernel_size = (kernel.len() as f32).sqrt() as u32;
177 if kernel_size * kernel_size != kernel.len() as u32 {
178 return Err(GpuError::Internal("Kernel must be square".to_string()));
179 }
180 if kernel_size % 2 == 0 {
181 return Err(GpuError::Internal("Kernel size must be odd".to_string()));
182 }
183
184 let pipeline = Self::get_convolve_pipeline(device)?;
185 let layout = Self::get_bind_group_layout_with_kernel(device)?;
186
187 Self::execute_convolve(
188 device,
189 pipeline,
190 layout,
191 input,
192 output,
193 width,
194 height,
195 kernel,
196 kernel_size,
197 normalize,
198 )
199 }
200
201 #[allow(clippy::too_many_arguments)]
202 fn execute_filter(
203 device: &GpuDevice,
204 pipeline: &ComputePipeline,
205 layout: &BindGroupLayout,
206 input: &[u8],
207 output: &mut [u8],
208 width: u32,
209 height: u32,
210 kernel_size: u32,
211 filter_type: u32,
212 sigma: f32,
213 ) -> Result<()> {
214 let input_buffer = utils::create_storage_buffer(device, input.len() as u64)?;
216 let output_buffer = utils::create_storage_buffer(device, output.len() as u64)?;
217
218 device.queue().write_buffer(input_buffer.buffer(), 0, input);
220
221 let params = FilterParams {
223 width,
224 height,
225 stride: width,
226 kernel_size,
227 normalize: 1,
228 filter_type,
229 padding: 0,
230 sigma,
231 };
232 let params_bytes = bytemuck::bytes_of(¶ms);
233 let params_buffer = utils::create_uniform_buffer(device, params_bytes)?;
234
235 let compiler = ShaderCompiler::new(device);
237 let bind_group = compiler.create_bind_group(
238 "Filter Bind Group",
239 layout,
240 &[
241 wgpu::BindGroupEntry {
242 binding: 0,
243 resource: input_buffer.buffer().as_entire_binding(),
244 },
245 wgpu::BindGroupEntry {
246 binding: 1,
247 resource: output_buffer.buffer().as_entire_binding(),
248 },
249 wgpu::BindGroupEntry {
250 binding: 2,
251 resource: params_buffer.buffer().as_entire_binding(),
252 },
253 ],
254 );
255
256 Self::dispatch_compute(device, pipeline, &bind_group, width, height)?;
258
259 let readback_buffer = utils::create_readback_buffer(device, output.len() as u64)?;
261 let mut encoder = device
262 .device()
263 .create_command_encoder(&wgpu::CommandEncoderDescriptor {
264 label: Some("Filter Copy Encoder"),
265 });
266
267 output_buffer.copy_to(&mut encoder, &readback_buffer, 0, 0, output.len() as u64)?;
268
269 device.queue().submit(Some(encoder.finish()));
270 device.wait();
271
272 let result = readback_buffer.read(device, 0, output.len() as u64)?;
273 output.copy_from_slice(&result);
274
275 Ok(())
276 }
277
278 #[allow(clippy::too_many_arguments)]
279 fn execute_convolve(
280 device: &GpuDevice,
281 pipeline: &ComputePipeline,
282 layout: &BindGroupLayout,
283 input: &[u8],
284 output: &mut [u8],
285 width: u32,
286 height: u32,
287 kernel: &[f32],
288 kernel_size: u32,
289 normalize: bool,
290 ) -> Result<()> {
291 let input_buffer = utils::create_storage_buffer(device, input.len() as u64)?;
293 let output_buffer = utils::create_storage_buffer(device, output.len() as u64)?;
294
295 device.queue().write_buffer(input_buffer.buffer(), 0, input);
297
298 let kernel_bytes = bytemuck::cast_slice(kernel);
300 let kernel_buffer = utils::create_storage_buffer(device, kernel_bytes.len() as u64)?;
301 device
302 .queue()
303 .write_buffer(kernel_buffer.buffer(), 0, kernel_bytes);
304
305 let params = FilterParams {
307 width,
308 height,
309 stride: width,
310 kernel_size,
311 normalize: u32::from(normalize),
312 filter_type: 0, padding: 0,
314 sigma: 0.0,
315 };
316 let params_bytes = bytemuck::bytes_of(¶ms);
317 let params_buffer = utils::create_uniform_buffer(device, params_bytes)?;
318
319 let compiler = ShaderCompiler::new(device);
321 let bind_group = compiler.create_bind_group(
322 "Filter Bind Group",
323 layout,
324 &[
325 wgpu::BindGroupEntry {
326 binding: 0,
327 resource: input_buffer.buffer().as_entire_binding(),
328 },
329 wgpu::BindGroupEntry {
330 binding: 1,
331 resource: output_buffer.buffer().as_entire_binding(),
332 },
333 wgpu::BindGroupEntry {
334 binding: 2,
335 resource: params_buffer.buffer().as_entire_binding(),
336 },
337 wgpu::BindGroupEntry {
338 binding: 3,
339 resource: kernel_buffer.buffer().as_entire_binding(),
340 },
341 ],
342 );
343
344 Self::dispatch_compute(device, pipeline, &bind_group, width, height)?;
346
347 let readback_buffer = utils::create_readback_buffer(device, output.len() as u64)?;
349 let mut encoder = device
350 .device()
351 .create_command_encoder(&wgpu::CommandEncoderDescriptor {
352 label: Some("Filter Copy Encoder"),
353 });
354
355 output_buffer.copy_to(&mut encoder, &readback_buffer, 0, 0, output.len() as u64)?;
356
357 device.queue().submit(Some(encoder.finish()));
358 device.wait();
359
360 let result = readback_buffer.read(device, 0, output.len() as u64)?;
361 output.copy_from_slice(&result);
362
363 Ok(())
364 }
365
366 fn dispatch_compute(
367 device: &GpuDevice,
368 pipeline: &ComputePipeline,
369 bind_group: &BindGroup,
370 width: u32,
371 height: u32,
372 ) -> Result<()> {
373 let mut encoder = device
374 .device()
375 .create_command_encoder(&wgpu::CommandEncoderDescriptor {
376 label: Some("Filter Compute Encoder"),
377 });
378
379 {
380 let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
381 label: Some("Filter Compute Pass"),
382 timestamp_writes: None,
383 });
384
385 compute_pass.set_pipeline(pipeline);
386 compute_pass.set_bind_group(0, bind_group, &[]);
387
388 let (dispatch_x, dispatch_y) = utils::calculate_dispatch_size(width, height, (16, 16));
389 compute_pass.dispatch_workgroups(dispatch_x, dispatch_y, 1);
390 }
391
392 device.queue().submit(Some(encoder.finish()));
393 Ok(())
394 }
395
396 fn calculate_kernel_size(sigma: f32) -> u32 {
397 let radius = (3.0 * sigma).ceil() as u32;
399 2 * radius + 1
400 }
401
402 fn get_bind_group_layout(device: &GpuDevice) -> Result<&'static BindGroupLayout> {
403 static LAYOUT: OnceCell<BindGroupLayout> = OnceCell::new();
404
405 Ok(LAYOUT.get_or_init(|| {
406 let compiler = ShaderCompiler::new(device);
407 let entries = BindGroupLayoutBuilder::new()
408 .add_storage_buffer_read_only(0) .add_storage_buffer(1) .add_uniform_buffer(2) .build();
412
413 compiler.create_bind_group_layout("Filter Bind Group Layout", &entries)
414 }))
415 }
416
417 fn get_bind_group_layout_with_kernel(device: &GpuDevice) -> Result<&'static BindGroupLayout> {
418 static LAYOUT: OnceCell<BindGroupLayout> = OnceCell::new();
419
420 Ok(LAYOUT.get_or_init(|| {
421 let compiler = ShaderCompiler::new(device);
422 let entries = BindGroupLayoutBuilder::new()
423 .add_storage_buffer_read_only(0) .add_storage_buffer(1) .add_uniform_buffer(2) .add_storage_buffer_read_only(3) .build();
428
429 compiler.create_bind_group_layout("Filter Bind Group Layout (with kernel)", &entries)
430 }))
431 }
432
433 fn init_pipeline(
434 device: &GpuDevice,
435 name: &str,
436 entry_point: &str,
437 layout_fn: fn(&GpuDevice) -> Result<&'static BindGroupLayout>,
438 ) -> std::result::Result<ComputePipeline, String> {
439 let compiler = ShaderCompiler::new(device);
440 let shader = compiler
441 .compile(
442 "Filter Shader",
443 ShaderSource::Embedded(crate::shader::embedded::FILTER_SHADER),
444 )
445 .map_err(|e| format!("Failed to compile filter shader: {e}"))?;
446
447 let layout =
448 layout_fn(device).map_err(|e| format!("Failed to create bind group layout: {e}"))?;
449
450 compiler
451 .create_pipeline(name, &shader, entry_point, layout)
452 .map_err(|e| format!("Failed to create pipeline: {e}"))
453 }
454
455 fn get_gaussian_pipeline(device: &GpuDevice) -> Result<&'static ComputePipeline> {
456 static PIPELINE: OnceCell<std::result::Result<ComputePipeline, String>> = OnceCell::new();
457
458 PIPELINE
459 .get_or_init(|| {
460 FilterOperation::init_pipeline(
461 device,
462 "Gaussian Blur Pipeline",
463 "convolve_main",
464 Self::get_bind_group_layout,
465 )
466 })
467 .as_ref()
468 .map_err(|e| crate::GpuError::PipelineCreation(e.clone()))
469 }
470
471 fn get_sharpen_pipeline(device: &GpuDevice) -> Result<&'static ComputePipeline> {
472 static PIPELINE: OnceCell<std::result::Result<ComputePipeline, String>> = OnceCell::new();
473
474 PIPELINE
475 .get_or_init(|| {
476 FilterOperation::init_pipeline(
477 device,
478 "Sharpen Pipeline",
479 "unsharp_mask",
480 Self::get_bind_group_layout,
481 )
482 })
483 .as_ref()
484 .map_err(|e| crate::GpuError::PipelineCreation(e.clone()))
485 }
486
487 fn get_edge_detect_pipeline(device: &GpuDevice) -> Result<&'static ComputePipeline> {
488 static PIPELINE: OnceCell<std::result::Result<ComputePipeline, String>> = OnceCell::new();
489
490 PIPELINE
491 .get_or_init(|| {
492 FilterOperation::init_pipeline(
493 device,
494 "Edge Detect Pipeline",
495 "edge_detect",
496 Self::get_bind_group_layout,
497 )
498 })
499 .as_ref()
500 .map_err(|e| crate::GpuError::PipelineCreation(e.clone()))
501 }
502
503 fn get_convolve_pipeline(device: &GpuDevice) -> Result<&'static ComputePipeline> {
504 static PIPELINE: OnceCell<std::result::Result<ComputePipeline, String>> = OnceCell::new();
505
506 PIPELINE
507 .get_or_init(|| {
508 FilterOperation::init_pipeline(
509 device,
510 "Convolve Pipeline",
511 "convolve_main",
512 Self::get_bind_group_layout_with_kernel,
513 )
514 })
515 .as_ref()
516 .map_err(|e| crate::GpuError::PipelineCreation(e.clone()))
517 }
518}
519
/// Builds a normalised 1-D Gaussian kernel with radius `ceil(3 * sigma)`.
///
/// The returned vector has `2 * radius + 1` weights that sum to 1 and are
/// symmetric around the centre tap.
///
/// A `sigma` that is zero, negative, NaN or infinite yields the identity
/// kernel `[1.0]`. For a positive `sigma` so small that `2 * sigma²` underflows
/// to zero, the centre tap is still exactly `1.0` (it is never computed as
/// `exp(0 / 0)`), so the result stays finite.
///
/// # Panics
///
/// An extremely large finite `sigma` makes the kernel length overflow or
/// exceed what `Vec` can allocate.
#[must_use]
pub fn gaussian_kernel_1d(sigma: f32) -> Vec<f32> {
    // NaN fails every ordered comparison, so it has to be rejected explicitly.
    if !sigma.is_finite() || sigma <= 0.0 {
        return vec![1.0_f32];
    }

    let radius = (3.0 * sigma).ceil() as usize;
    let len = 2 * radius + 1;
    let two_sigma_sq = 2.0 * sigma * sigma;

    let mut kernel = Vec::with_capacity(len);
    let mut sum = 0.0_f32;
    for i in 0..len {
        // exp(0) is exactly 1; special-casing the centre avoids 0/0 = NaN when
        // `two_sigma_sq` has underflowed to zero.
        let v = if i == radius {
            1.0_f32
        } else {
            let x = i as f32 - radius as f32;
            (-x * x / two_sigma_sq).exp()
        };
        kernel.push(v);
        sum += v;
    }

    for k in &mut kernel {
        *k /= sum;
    }
    kernel
}
549
550pub fn gaussian_blur_separable(
563 input: &[u8],
564 output: &mut [u8],
565 width: u32,
566 height: u32,
567 sigma: f32,
568) -> crate::Result<()> {
569 utils::validate_dimensions(width, height)?;
570 utils::validate_buffer_size(input, width, height, 4)?;
571 utils::validate_buffer_size(output, width, height, 4)?;
572
573 let w = width as usize;
574 let h = height as usize;
575 let kernel = gaussian_kernel_1d(sigma);
576 let radius = kernel.len() / 2;
577
578 let mut h_pass = vec![0.0_f32; w * h * 4];
580 for row in 0..h {
581 for col in 0..w {
582 let mut acc = [0.0_f32; 4];
583 let mut wsum = 0.0_f32;
584 for (ki, &kw) in kernel.iter().enumerate() {
585 let sc = col as isize + ki as isize - radius as isize;
586 if sc < 0 || sc >= w as isize {
587 continue;
588 }
589 let src = (row * w + sc as usize) * 4;
590 for c in 0..4 {
591 acc[c] += kw * input[src + c] as f32;
592 }
593 wsum += kw;
594 }
595 let dst = (row * w + col) * 4;
596 let inv = if wsum > 0.0 { 1.0 / wsum } else { 1.0 };
597 for c in 0..4 {
598 h_pass[dst + c] = acc[c] * inv;
599 }
600 }
601 }
602
603 for row in 0..h {
605 for col in 0..w {
606 let mut acc = [0.0_f32; 4];
607 let mut wsum = 0.0_f32;
608 for (ki, &kw) in kernel.iter().enumerate() {
609 let sr = row as isize + ki as isize - radius as isize;
610 if sr < 0 || sr >= h as isize {
611 continue;
612 }
613 let src = (sr as usize * w + col) * 4;
614 for c in 0..4 {
615 acc[c] += kw * h_pass[src + c];
616 }
617 wsum += kw;
618 }
619 let dst = (row * w + col) * 4;
620 let inv = if wsum > 0.0 { 1.0 / wsum } else { 1.0 };
621 for c in 0..4 {
622 output[dst + c] = (acc[c] * inv).round().clamp(0.0, 255.0) as u8;
623 }
624 }
625 }
626
627 Ok(())
628}
629
630use rayon::prelude::*;
631
632pub fn gaussian_blur_separable_parallel(
648 input: &[u8],
649 output: &mut [u8],
650 width: u32,
651 height: u32,
652 sigma: f32,
653) -> crate::Result<()> {
654 utils::validate_dimensions(width, height)?;
655 utils::validate_buffer_size(input, width, height, 4)?;
656 utils::validate_buffer_size(output, width, height, 4)?;
657
658 let w = width as usize;
659 let h = height as usize;
660 let kernel = gaussian_kernel_1d(sigma);
661 let radius = kernel.len() / 2;
662
663 let mut h_pass = vec![0.0_f32; w * h * 4];
665 h_pass
666 .par_chunks_exact_mut(w * 4)
667 .enumerate()
668 .for_each(|(row, row_out)| {
669 for col in 0..w {
670 let mut acc = [0.0_f32; 4];
671 let mut wsum = 0.0_f32;
672 for (ki, &kw) in kernel.iter().enumerate() {
673 let sc = col as isize + ki as isize - radius as isize;
674 if sc < 0 || sc >= w as isize {
675 continue;
676 }
677 let src = (row * w + sc as usize) * 4;
678 for c in 0..4 {
679 acc[c] += kw * input[src + c] as f32;
680 }
681 wsum += kw;
682 }
683 let inv = if wsum > 0.0 { 1.0 / wsum } else { 1.0 };
684 let dst = col * 4;
685 for c in 0..4 {
686 row_out[dst + c] = acc[c] * inv;
687 }
688 }
689 });
690
691 output
693 .par_chunks_exact_mut(4)
694 .enumerate()
695 .for_each(|(px_idx, px_out)| {
696 let row = px_idx / w;
697 let col = px_idx % w;
698 let mut acc = [0.0_f32; 4];
699 let mut wsum = 0.0_f32;
700 for (ki, &kw) in kernel.iter().enumerate() {
701 let sr = row as isize + ki as isize - radius as isize;
702 if sr < 0 || sr >= h as isize {
703 continue;
704 }
705 let src = (sr as usize * w + col) * 4;
706 for c in 0..4 {
707 acc[c] += kw * h_pass[src + c];
708 }
709 wsum += kw;
710 }
711 let inv = if wsum > 0.0 { 1.0 / wsum } else { 1.0 };
712 for c in 0..4 {
713 px_out[c] = (acc[c] * inv).round().clamp(0.0, 255.0) as u8;
714 }
715 });
716
717 Ok(())
718}
719
720pub fn box_blur(
734 data: &[u8],
735 width: u32,
736 height: u32,
737 channels: u32,
738 radius: u32,
739) -> crate::Result<Vec<u8>> {
740 let w = width as usize;
741 let h = height as usize;
742 let ch = channels as usize;
743 let expected = w * h * ch;
744 if data.len() != expected {
745 return Err(crate::GpuError::InvalidBufferSize {
746 expected,
747 actual: data.len(),
748 });
749 }
750 if w == 0 || h == 0 {
751 return Ok(data.to_vec());
752 }
753
754 let r = radius as isize;
755
756 let mut h_pass = vec![0u32; w * h * ch];
760 for row in 0..h {
761 for c in 0..ch {
762 let right0 = r.min(w as isize - 1) as usize;
764 let mut window_sum: u32 = 0;
765 for kc in 0..=right0 {
766 window_sum += u32::from(data[(row * w + kc) * ch + c]);
767 }
768
769 for col in 0..w {
770 let left = (col as isize - r).max(0) as usize;
772 let right = (col as isize + r).min(w as isize - 1) as usize;
773
774 if col > 0 {
775 let prev_left = ((col as isize - 1) - r).max(0) as usize;
777 let prev_right = ((col as isize - 1) + r).min(w as isize - 1) as usize;
778 if left > prev_left {
780 window_sum -= u32::from(data[(row * w + prev_left) * ch + c]);
781 }
782 if right > prev_right {
784 window_sum += u32::from(data[(row * w + right) * ch + c]);
785 }
786 }
787
788 let window_len = (right - left + 1) as u32;
789 h_pass[(row * w + col) * ch + c] = (window_sum + window_len / 2) / window_len;
791 }
792 }
793 }
794
795 let mut output = vec![0u8; expected];
798 for col in 0..w {
799 for c in 0..ch {
800 let bot0 = r.min(h as isize - 1) as usize;
802 let mut window_sum: u32 = 0;
803 for kr in 0..=bot0 {
804 window_sum += h_pass[(kr * w + col) * ch + c];
805 }
806
807 for row in 0..h {
808 let top = (row as isize - r).max(0) as usize;
809 let bot = (row as isize + r).min(h as isize - 1) as usize;
810
811 if row > 0 {
812 let prev_top = ((row as isize - 1) - r).max(0) as usize;
813 let prev_bot = ((row as isize - 1) + r).min(h as isize - 1) as usize;
814 if top > prev_top {
815 window_sum -= h_pass[(prev_top * w + col) * ch + c];
816 }
817 if bot > prev_bot {
818 window_sum += h_pass[(bot * w + col) * ch + c];
819 }
820 }
821
822 let window_len = (bot - top + 1) as u32;
823 let avg = (window_sum + window_len / 2) / window_len;
824 output[(row * w + col) * ch + c] = avg.clamp(0, 255) as u8;
825 }
826 }
827 }
828
829 Ok(output)
830}
831
832pub fn median_filter(
844 data: &[u8],
845 width: u32,
846 height: u32,
847 channels: u32,
848 radius: u32,
849) -> crate::Result<Vec<u8>> {
850 let w = width as usize;
851 let h = height as usize;
852 let ch = channels as usize;
853 let expected = w * h * ch;
854 if data.len() != expected {
855 return Err(crate::GpuError::InvalidBufferSize {
856 expected,
857 actual: data.len(),
858 });
859 }
860 if w == 0 || h == 0 {
861 return Ok(data.to_vec());
862 }
863
864 let r = radius as isize;
865 let window_len = ((2 * r + 1) * (2 * r + 1)) as usize;
866 let mut output = vec![0u8; expected];
867
868 for row in 0..h {
869 for col in 0..w {
870 for c in 0..ch {
871 let mut window: Vec<u8> = Vec::with_capacity(window_len);
872 for dy in -r..=r {
873 for dx in -r..=r {
874 let sr = (row as isize + dy).clamp(0, h as isize - 1) as usize;
875 let sc = (col as isize + dx).clamp(0, w as isize - 1) as usize;
876 window.push(data[(sr * w + sc) * ch + c]);
877 }
878 }
879 window.sort_unstable();
880 output[(row * w + col) * ch + c] = window[window.len() / 2];
881 }
882 }
883 }
884
885 Ok(output)
886}
887
888pub fn bilateral_filter(
902 data: &[u8],
903 width: u32,
904 height: u32,
905 channels: u32,
906 sigma_spatial: f32,
907 sigma_range: f32,
908) -> crate::Result<Vec<u8>> {
909 if channels != 4 {
910 return Err(crate::GpuError::NotSupported(format!(
911 "bilateral_filter requires channels == 4, got {channels}"
912 )));
913 }
914 utils::validate_buffer_size(data, width, height, 4)?;
915 let mut output = vec![0u8; data.len()];
916 super::DenoiseOperation::denoise_bilateral_cpu(
917 data,
918 &mut output,
919 width,
920 height,
921 sigma_spatial,
922 sigma_range,
923 )?;
924 Ok(output)
925}
926
/// Returns the largest absolute difference between corresponding bytes of `a` and `b`.
///
/// Only the overlapping prefix is compared: if the slices differ in length the
/// extra bytes of the longer one are ignored. Two empty slices give `0`.
#[must_use]
pub fn max_channel_diff(a: &[u8], b: &[u8]) -> u32 {
    let mut worst = 0_u32;
    for (&lhs, &rhs) in a.iter().zip(b) {
        let gap = u32::from(lhs.max(rhs) - lhs.min(rhs));
        worst = worst.max(gap);
    }
    worst
}
939
#[cfg(test)]
mod tests {
    use super::*;

    // ---- shared fixtures -------------------------------------------------

    // Byte length of a tightly packed RGBA image.
    fn rgba_len(w: u32, h: u32) -> usize {
        (w * h * 4) as usize
    }

    // Opaque RGBA checkerboard: white where `row + col` is even, black otherwise.
    fn checkerboard_rgba(w: u32, h: u32) -> Vec<u8> {
        let mut pixels = vec![0u8; rgba_len(w, h)];
        for (idx, px) in pixels.chunks_exact_mut(4).enumerate() {
            let row = idx / w as usize;
            let col = idx % w as usize;
            let v = if (row + col) % 2 == 0 { 255u8 } else { 0u8 };
            px.copy_from_slice(&[v, v, v, 255u8]);
        }
        pixels
    }

    // Largest value found in any R, G or B byte (alpha ignored).
    fn brightest_rgb(buf: &[u8]) -> u8 {
        buf.chunks(4)
            .flat_map(|px| &px[..3])
            .copied()
            .max()
            .unwrap_or(0)
    }

    // Spread (max - min) of the first channel across all pixels.
    fn first_channel_range(buf: &[u8]) -> u32 {
        let hi = buf.chunks(4).map(|px| px[0] as u32).max().unwrap_or(0);
        let lo = buf.chunks(4).map(|px| px[0] as u32).min().unwrap_or(0);
        hi - lo
    }

    // ---- gaussian_kernel_1d ----------------------------------------------

    #[test]
    fn test_kernel_sums_to_one() {
        let sum: f32 = gaussian_kernel_1d(1.0).iter().sum();
        assert!((sum - 1.0).abs() < 1e-5, "kernel sum = {sum}");
    }

    #[test]
    fn test_kernel_is_symmetric() {
        let k = gaussian_kernel_1d(2.0);
        let mirrored = k.iter().zip(k.iter().rev()).take(k.len() / 2);
        for (i, (front, back)) in mirrored.enumerate() {
            assert!(
                (front - back).abs() < 1e-6,
                "asymmetric at index {i}: {} vs {}",
                front,
                back
            );
        }
    }

    #[test]
    fn test_kernel_center_is_largest() {
        let k = gaussian_kernel_1d(1.5);
        let center = k[k.len() / 2];
        for v in k.iter().copied() {
            assert!(center >= v, "center {center} not >= {v}");
        }
    }

    #[test]
    fn test_kernel_zero_sigma_returns_identity() {
        let k = gaussian_kernel_1d(0.0);
        assert_eq!(k.len(), 1);
        assert!((k[0] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_kernel_negative_sigma_returns_identity() {
        let k = gaussian_kernel_1d(-1.0);
        assert_eq!(k.len(), 1);
        assert!((k[0] - 1.0).abs() < 1e-6);
    }

    // ---- gaussian_blur_separable -----------------------------------------

    #[test]
    fn test_blur_uniform_image_unchanged() {
        let (w, h) = (8u32, 8u32);
        let input: Vec<u8> = [128u8, 128, 128, 255]
            .iter()
            .copied()
            .cycle()
            .take(rgba_len(w, h))
            .collect();
        let mut output = vec![0u8; rgba_len(w, h)];
        gaussian_blur_separable(&input, &mut output, w, h, 1.5).expect("blur should succeed");
        for (i, (&inp, &out)) in input.iter().zip(output.iter()).enumerate() {
            assert!(
                (inp as i32 - out as i32).unsigned_abs() <= 1,
                "pixel {i}: input={inp} output={out}"
            );
        }
    }

    #[test]
    fn test_blur_reduces_contrast() {
        let (w, h) = (4u32, 4u32);
        let input = checkerboard_rgba(w, h);
        let mut output = vec![0u8; rgba_len(w, h)];
        gaussian_blur_separable(&input, &mut output, w, h, 1.0).expect("blur should succeed");
        let max_rgb = brightest_rgb(&output);
        assert!(
            max_rgb < 255,
            "max_rgb after blur = {max_rgb}; expected < 255"
        );
    }

    #[test]
    fn test_blur_size_mismatch_returns_error() {
        let (w, h) = (4u32, 4u32);
        let input = vec![0u8; rgba_len(w, h)];
        let mut output = vec![0u8; 10];
        let result = gaussian_blur_separable(&input, &mut output, w, h, 1.0);
        assert!(result.is_err());
    }

    #[test]
    fn test_blur_single_pixel_passthrough() {
        let input = vec![100u8, 150u8, 200u8, 255u8];
        let mut output = vec![0u8; 4];
        gaussian_blur_separable(&input, &mut output, 1, 1, 1.0).expect("blur should succeed");
        assert_eq!(output[0], 100);
        assert_eq!(output[1], 150);
        assert_eq!(output[2], 200);
        assert_eq!(output[3], 255);
    }

    // ---- gaussian_blur_separable_parallel --------------------------------

    #[test]
    fn test_parallel_blur_matches_serial_uniform_image() {
        let (w, h) = (16u32, 16u32);
        let input: Vec<u8> = vec![128u8; rgba_len(w, h)];
        let mut serial = vec![0u8; rgba_len(w, h)];
        let mut parallel = vec![0u8; rgba_len(w, h)];
        gaussian_blur_separable(&input, &mut serial, w, h, 1.5).expect("serial blur");
        gaussian_blur_separable_parallel(&input, &mut parallel, w, h, 1.5).expect("parallel blur");
        assert_eq!(
            max_channel_diff(&serial, &parallel),
            0,
            "serial and parallel must agree on uniform image"
        );
    }

    #[test]
    fn test_parallel_blur_matches_serial_random_image() {
        let (w, h) = (8u32, 8u32);
        let input: Vec<u8> = (0..rgba_len(w, h))
            .map(|i| ((i * 37 + 13) % 256) as u8)
            .collect();
        let mut serial = vec![0u8; rgba_len(w, h)];
        let mut parallel = vec![0u8; rgba_len(w, h)];
        gaussian_blur_separable(&input, &mut serial, w, h, 1.0).expect("serial blur");
        gaussian_blur_separable_parallel(&input, &mut parallel, w, h, 1.0).expect("parallel blur");
        let max_diff = max_channel_diff(&serial, &parallel);
        assert_eq!(max_diff, 0, "serial and parallel outputs must be identical");
    }

    #[test]
    fn test_parallel_blur_single_pixel_passthrough() {
        let input = vec![77u8, 88, 99, 255];
        let mut output = vec![0u8; 4];
        gaussian_blur_separable_parallel(&input, &mut output, 1, 1, 2.0)
            .expect("single pixel parallel blur");
        assert_eq!(output[0], 77);
        assert_eq!(output[1], 88);
        assert_eq!(output[2], 99);
        assert_eq!(output[3], 255);
    }

    #[test]
    fn test_parallel_blur_size_mismatch_returns_error() {
        let input = vec![0u8; 4 * 4 * 4];
        let mut output = vec![0u8; 5];
        let res = gaussian_blur_separable_parallel(&input, &mut output, 4, 4, 1.0);
        assert!(res.is_err());
    }

    #[test]
    fn test_parallel_blur_reduces_contrast() {
        let (w, h) = (8u32, 8u32);
        let input = checkerboard_rgba(w, h);
        let mut output = vec![0u8; rgba_len(w, h)];
        gaussian_blur_separable_parallel(&input, &mut output, w, h, 1.5)
            .expect("parallel contrast blur");
        let max_rgb = brightest_rgb(&output);
        assert!(
            max_rgb < 255,
            "parallel blur should reduce max brightness; got {max_rgb}"
        );
    }

    #[test]
    fn test_parallel_blur_large_sigma_heavy_smoothing() {
        let (w, h) = (16u32, 16u32);
        let input = checkerboard_rgba(w, h);
        let mut out_small = vec![0u8; rgba_len(w, h)];
        let mut out_large = vec![0u8; rgba_len(w, h)];
        gaussian_blur_separable_parallel(&input, &mut out_small, w, h, 0.5).expect("small sigma");
        gaussian_blur_separable_parallel(&input, &mut out_large, w, h, 3.0).expect("large sigma");

        let range_small: u32 = first_channel_range(&out_small);
        let range_large: u32 = first_channel_range(&out_large);
        assert!(
            range_large <= range_small,
            "larger sigma should produce smaller contrast range; small={range_small}, large={range_large}"
        );
    }

    #[test]
    fn test_parallel_blur_wide_image() {
        let (w, h) = (32u32, 4u32);
        let input: Vec<u8> = (0..rgba_len(w, h)).map(|i| (i % 256) as u8).collect();
        let mut output = vec![0u8; rgba_len(w, h)];
        gaussian_blur_separable_parallel(&input, &mut output, w, h, 1.0)
            .expect("wide image parallel blur");
        assert_eq!(output.len(), (w * h * 4) as usize);
    }

    #[test]
    fn test_parallel_blur_tall_image() {
        let (w, h) = (4u32, 32u32);
        let input: Vec<u8> = (0..rgba_len(w, h)).map(|i| (i % 256) as u8).collect();
        let mut output = vec![0u8; rgba_len(w, h)];
        gaussian_blur_separable_parallel(&input, &mut output, w, h, 1.0)
            .expect("tall image parallel blur");
        assert_eq!(output.len(), (w * h * 4) as usize);
    }

    // ---- max_channel_diff ------------------------------------------------

    #[test]
    fn test_max_channel_diff_identical() {
        let a = vec![128u8; 16];
        assert_eq!(max_channel_diff(&a, &a), 0);
    }

    #[test]
    fn test_max_channel_diff_known_values() {
        let a = vec![100u8, 200, 50, 255];
        let b = vec![90u8, 210, 50, 255];
        assert_eq!(max_channel_diff(&a, &b), 10);
    }

    // ---- box_blur --------------------------------------------------------

    #[test]
    fn test_box_blur_uniform() {
        let (w, h, ch) = (4u32, 4u32, 3u32);
        let value: u8 = 128;
        let input = vec![value; (w * h * ch) as usize];
        let output = box_blur(&input, w, h, ch, 2).expect("box_blur should succeed");
        for (i, &v) in output.iter().enumerate() {
            assert!(
                (v as i32 - value as i32).abs() <= 1,
                "pixel byte {i}: expected {value}, got {v}"
            );
        }
    }

    #[test]
    fn test_box_blur_spike() {
        let (w, h, ch) = (7u32, 7u32, 1u32);
        let row_len = w as usize;
        let (cx, cy) = (3usize, 3usize);
        let mut input = vec![0u8; (w * h * ch) as usize];
        input[cy * row_len + cx] = 255;

        let output = box_blur(&input, w, h, ch, 1).expect("box_blur spike should succeed");

        let centre = output[cy * row_len + cx];
        assert!(
            centre < 255,
            "centre pixel should be reduced after box blur, got {centre}"
        );

        let right = output[cy * row_len + cx + 1];
        let below = output[(cy + 1) * row_len + cx];
        assert!(
            right > 0 || below > 0,
            "neighbours should receive energy; right={right}, below={below}"
        );
    }

    #[test]
    fn test_box_blur_size_mismatch_returns_error() {
        let result = box_blur(&[0u8; 10], 4, 4, 1, 1);
        assert!(result.is_err(), "expected error on size mismatch");
    }

    // ---- median_filter ---------------------------------------------------

    #[test]
    fn test_median_removes_outlier() {
        let (w, h, ch) = (5u32, 5u32, 1u32);
        let row_len = w as usize;
        let (cx, cy) = (2usize, 2usize);
        let mut input = vec![100u8; (w * h * ch) as usize];
        input[cy * row_len + cx] = 255;

        let output = median_filter(&input, w, h, ch, 1).expect("median_filter should succeed");

        let centre = output[cy * row_len + cx];
        assert_eq!(
            centre, 100,
            "median should remove the outlier; centre={centre}"
        );
    }

    #[test]
    fn test_median_uniform_image() {
        let (w, h, ch) = (4u32, 4u32, 4u32);
        let input = vec![77u8; (w * h * ch) as usize];
        let output = median_filter(&input, w, h, ch, 2).expect("median_filter uniform");
        assert!(output.iter().all(|&v| v == 77));
    }

    #[test]
    fn test_median_size_mismatch_returns_error() {
        let result = median_filter(&[0u8; 5], 4, 4, 1, 1);
        assert!(result.is_err(), "expected error on size mismatch");
    }

    // ---- bilateral_filter ------------------------------------------------

    #[test]
    fn test_bilateral_edge_preserving() {
        let (w, h) = (10u32, 10u32);
        let row_len = w as usize;

        // Vertical step edge: left half black, right half white, fully opaque.
        let mut input = vec![0u8; rgba_len(w, h)];
        for (idx, px) in input.chunks_exact_mut(4).enumerate() {
            let v: u8 = if idx % row_len >= 5 { 255 } else { 0 };
            px.copy_from_slice(&[v, v, v, 255]);
        }

        let output =
            bilateral_filter(&input, w, h, 4, 2.0, 10.0).expect("bilateral_filter should succeed");

        // Well inside the dark half the pixels must stay dark.
        for row in 0..h as usize {
            let col = 1usize;
            let base = (row * row_len + col) * 4;
            for c in 0..3 {
                assert!(
                    output[base + c] < 64,
                    "row={row} col={col} ch={c}: expected near 0, got {}",
                    output[base + c]
                );
            }
        }

        // Well inside the bright half the pixels must stay bright.
        for row in 0..h as usize {
            let col = 8usize;
            let base = (row * row_len + col) * 4;
            for c in 0..3 {
                assert!(
                    output[base + c] > 191,
                    "row={row} col={col} ch={c}: expected near 255, got {}",
                    output[base + c]
                );
            }
        }
    }

    #[test]
    fn test_bilateral_wrong_channels_returns_error() {
        let result = bilateral_filter(&[0u8; 9], 3, 3, 1, 2.0, 30.0);
        assert!(result.is_err(), "bilateral requires channels == 4");
    }
}