1use crate::{
4 shader::{BindGroupLayoutBuilder, ShaderCompiler, ShaderSource},
5 GpuDevice, GpuError, Result,
6};
7use bytemuck::{Pod, Zeroable};
8use once_cell::sync::OnceCell;
9use wgpu::{BindGroup, BindGroupLayout, ComputePipeline};
10
11use super::utils;
12
/// Parameter block handed to the filter shaders through a uniform buffer.
///
/// `#[repr(C)]` together with `Pod`/`Zeroable` lets the struct be viewed as raw
/// bytes via `bytemuck::bytes_of`; the eight 4-byte fields add up to 32 bytes.
///
/// NOTE(review): the field order presumably mirrors the parameter struct
/// declared in the embedded filter shader — confirm before reordering or
/// adding fields.
#[repr(C)]
#[derive(Copy, Clone, Pod, Zeroable)]
struct FilterParams {
    /// Image width as supplied by the caller.
    width: u32,
    /// Image height as supplied by the caller.
    height: u32,
    /// Row stride; both executors in this file set it equal to `width`.
    stride: u32,
    /// Side length of the (square) filter kernel.
    kernel_size: u32,
    /// Normalisation flag: 1 for the built-in filters, `u32::from(normalize)`
    /// for custom convolution.
    normalize: u32,
    /// Filter selector written by the callers in this file:
    /// 0 = custom convolution, 1 = Gaussian blur, 2 = sharpen, 3 = edge detect.
    filter_type: u32,
    /// Always written as 0 in this file.
    padding: u32,
    /// Gaussian sigma; `sharpen` stores its `amount` here, other filters use 0.0.
    sigma: f32,
}
25
/// Stateless namespace type for the GPU image-filter entry points.
///
/// Every operation is an associated function that receives the [`GpuDevice`]
/// explicitly; the type itself holds no data.
pub struct FilterOperation;
28
29impl FilterOperation {
30 #[allow(clippy::too_many_arguments)]
45 pub fn gaussian_blur(
46 device: &GpuDevice,
47 input: &[u8],
48 output: &mut [u8],
49 width: u32,
50 height: u32,
51 sigma: f32,
52 ) -> Result<()> {
53 utils::validate_dimensions(width, height)?;
54 utils::validate_buffer_size(input, width, height, 4)?;
55 utils::validate_buffer_size(output, width, height, 4)?;
56
57 let kernel_size = Self::calculate_kernel_size(sigma);
58 let pipeline = Self::get_gaussian_pipeline(device)?;
59 let layout = Self::get_bind_group_layout(device)?;
60
61 Self::execute_filter(
62 device,
63 pipeline,
64 layout,
65 input,
66 output,
67 width,
68 height,
69 kernel_size,
70 1, sigma,
72 )
73 }
74
75 #[allow(clippy::too_many_arguments)]
90 pub fn sharpen(
91 device: &GpuDevice,
92 input: &[u8],
93 output: &mut [u8],
94 width: u32,
95 height: u32,
96 amount: f32,
97 ) -> Result<()> {
98 utils::validate_dimensions(width, height)?;
99 utils::validate_buffer_size(input, width, height, 4)?;
100 utils::validate_buffer_size(output, width, height, 4)?;
101
102 let pipeline = Self::get_sharpen_pipeline(device)?;
103 let layout = Self::get_bind_group_layout(device)?;
104
105 Self::execute_filter(
106 device, pipeline, layout, input, output, width, height,
107 5, 2, amount,
110 )
111 }
112
113 pub fn edge_detect(
127 device: &GpuDevice,
128 input: &[u8],
129 output: &mut [u8],
130 width: u32,
131 height: u32,
132 ) -> Result<()> {
133 utils::validate_dimensions(width, height)?;
134 utils::validate_buffer_size(input, width, height, 4)?;
135 utils::validate_buffer_size(output, width, height, 4)?;
136
137 let pipeline = Self::get_edge_detect_pipeline(device)?;
138 let layout = Self::get_bind_group_layout(device)?;
139
140 Self::execute_filter(
141 device, pipeline, layout, input, output, width, height, 3, 3, 0.0,
144 )
145 }
146
147 #[allow(clippy::too_many_arguments)]
163 pub fn convolve(
164 device: &GpuDevice,
165 input: &[u8],
166 output: &mut [u8],
167 width: u32,
168 height: u32,
169 kernel: &[f32],
170 normalize: bool,
171 ) -> Result<()> {
172 utils::validate_dimensions(width, height)?;
173 utils::validate_buffer_size(input, width, height, 4)?;
174 utils::validate_buffer_size(output, width, height, 4)?;
175
176 let kernel_size = (kernel.len() as f32).sqrt() as u32;
177 if kernel_size * kernel_size != kernel.len() as u32 {
178 return Err(GpuError::Internal("Kernel must be square".to_string()));
179 }
180 if kernel_size % 2 == 0 {
181 return Err(GpuError::Internal("Kernel size must be odd".to_string()));
182 }
183
184 let pipeline = Self::get_convolve_pipeline(device)?;
185 let layout = Self::get_bind_group_layout_with_kernel(device)?;
186
187 Self::execute_convolve(
188 device,
189 pipeline,
190 layout,
191 input,
192 output,
193 width,
194 height,
195 kernel,
196 kernel_size,
197 normalize,
198 )
199 }
200
201 #[allow(clippy::too_many_arguments)]
202 fn execute_filter(
203 device: &GpuDevice,
204 pipeline: &ComputePipeline,
205 layout: &BindGroupLayout,
206 input: &[u8],
207 output: &mut [u8],
208 width: u32,
209 height: u32,
210 kernel_size: u32,
211 filter_type: u32,
212 sigma: f32,
213 ) -> Result<()> {
214 let input_buffer = utils::create_storage_buffer(device, input.len() as u64)?;
216 let output_buffer = utils::create_storage_buffer(device, output.len() as u64)?;
217
218 device.queue().write_buffer(input_buffer.buffer(), 0, input);
220
221 let params = FilterParams {
223 width,
224 height,
225 stride: width,
226 kernel_size,
227 normalize: 1,
228 filter_type,
229 padding: 0,
230 sigma,
231 };
232 let params_bytes = bytemuck::bytes_of(¶ms);
233 let params_buffer = utils::create_uniform_buffer(device, params_bytes)?;
234
235 let compiler = ShaderCompiler::new(device);
237 let bind_group = compiler.create_bind_group(
238 "Filter Bind Group",
239 layout,
240 &[
241 wgpu::BindGroupEntry {
242 binding: 0,
243 resource: input_buffer.buffer().as_entire_binding(),
244 },
245 wgpu::BindGroupEntry {
246 binding: 1,
247 resource: output_buffer.buffer().as_entire_binding(),
248 },
249 wgpu::BindGroupEntry {
250 binding: 2,
251 resource: params_buffer.buffer().as_entire_binding(),
252 },
253 ],
254 );
255
256 Self::dispatch_compute(device, pipeline, &bind_group, width, height)?;
258
259 let readback_buffer = utils::create_readback_buffer(device, output.len() as u64)?;
261 let mut encoder = device
262 .device()
263 .create_command_encoder(&wgpu::CommandEncoderDescriptor {
264 label: Some("Filter Copy Encoder"),
265 });
266
267 output_buffer.copy_to(&mut encoder, &readback_buffer, 0, 0, output.len() as u64)?;
268
269 device.queue().submit(Some(encoder.finish()));
270 device.wait();
271
272 let result = readback_buffer.read(device, 0, output.len() as u64)?;
273 output.copy_from_slice(&result);
274
275 Ok(())
276 }
277
278 #[allow(clippy::too_many_arguments)]
279 fn execute_convolve(
280 device: &GpuDevice,
281 pipeline: &ComputePipeline,
282 layout: &BindGroupLayout,
283 input: &[u8],
284 output: &mut [u8],
285 width: u32,
286 height: u32,
287 kernel: &[f32],
288 kernel_size: u32,
289 normalize: bool,
290 ) -> Result<()> {
291 let input_buffer = utils::create_storage_buffer(device, input.len() as u64)?;
293 let output_buffer = utils::create_storage_buffer(device, output.len() as u64)?;
294
295 device.queue().write_buffer(input_buffer.buffer(), 0, input);
297
298 let kernel_bytes = bytemuck::cast_slice(kernel);
300 let kernel_buffer = utils::create_storage_buffer(device, kernel_bytes.len() as u64)?;
301 device
302 .queue()
303 .write_buffer(kernel_buffer.buffer(), 0, kernel_bytes);
304
305 let params = FilterParams {
307 width,
308 height,
309 stride: width,
310 kernel_size,
311 normalize: u32::from(normalize),
312 filter_type: 0, padding: 0,
314 sigma: 0.0,
315 };
316 let params_bytes = bytemuck::bytes_of(¶ms);
317 let params_buffer = utils::create_uniform_buffer(device, params_bytes)?;
318
319 let compiler = ShaderCompiler::new(device);
321 let bind_group = compiler.create_bind_group(
322 "Filter Bind Group",
323 layout,
324 &[
325 wgpu::BindGroupEntry {
326 binding: 0,
327 resource: input_buffer.buffer().as_entire_binding(),
328 },
329 wgpu::BindGroupEntry {
330 binding: 1,
331 resource: output_buffer.buffer().as_entire_binding(),
332 },
333 wgpu::BindGroupEntry {
334 binding: 2,
335 resource: params_buffer.buffer().as_entire_binding(),
336 },
337 wgpu::BindGroupEntry {
338 binding: 3,
339 resource: kernel_buffer.buffer().as_entire_binding(),
340 },
341 ],
342 );
343
344 Self::dispatch_compute(device, pipeline, &bind_group, width, height)?;
346
347 let readback_buffer = utils::create_readback_buffer(device, output.len() as u64)?;
349 let mut encoder = device
350 .device()
351 .create_command_encoder(&wgpu::CommandEncoderDescriptor {
352 label: Some("Filter Copy Encoder"),
353 });
354
355 output_buffer.copy_to(&mut encoder, &readback_buffer, 0, 0, output.len() as u64)?;
356
357 device.queue().submit(Some(encoder.finish()));
358 device.wait();
359
360 let result = readback_buffer.read(device, 0, output.len() as u64)?;
361 output.copy_from_slice(&result);
362
363 Ok(())
364 }
365
366 fn dispatch_compute(
367 device: &GpuDevice,
368 pipeline: &ComputePipeline,
369 bind_group: &BindGroup,
370 width: u32,
371 height: u32,
372 ) -> Result<()> {
373 let mut encoder = device
374 .device()
375 .create_command_encoder(&wgpu::CommandEncoderDescriptor {
376 label: Some("Filter Compute Encoder"),
377 });
378
379 {
380 let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
381 label: Some("Filter Compute Pass"),
382 timestamp_writes: None,
383 });
384
385 compute_pass.set_pipeline(pipeline);
386 compute_pass.set_bind_group(0, bind_group, &[]);
387
388 let (dispatch_x, dispatch_y) = utils::calculate_dispatch_size(width, height, (16, 16));
389 compute_pass.dispatch_workgroups(dispatch_x, dispatch_y, 1);
390 }
391
392 device.queue().submit(Some(encoder.finish()));
393 Ok(())
394 }
395
396 fn calculate_kernel_size(sigma: f32) -> u32 {
397 let radius = (3.0 * sigma).ceil() as u32;
399 2 * radius + 1
400 }
401
402 fn get_bind_group_layout(device: &GpuDevice) -> Result<&'static BindGroupLayout> {
403 static LAYOUT: OnceCell<BindGroupLayout> = OnceCell::new();
404
405 Ok(LAYOUT.get_or_init(|| {
406 let compiler = ShaderCompiler::new(device);
407 let entries = BindGroupLayoutBuilder::new()
408 .add_storage_buffer_read_only(0) .add_storage_buffer(1) .add_uniform_buffer(2) .build();
412
413 compiler.create_bind_group_layout("Filter Bind Group Layout", &entries)
414 }))
415 }
416
417 fn get_bind_group_layout_with_kernel(device: &GpuDevice) -> Result<&'static BindGroupLayout> {
418 static LAYOUT: OnceCell<BindGroupLayout> = OnceCell::new();
419
420 Ok(LAYOUT.get_or_init(|| {
421 let compiler = ShaderCompiler::new(device);
422 let entries = BindGroupLayoutBuilder::new()
423 .add_storage_buffer_read_only(0) .add_storage_buffer(1) .add_uniform_buffer(2) .add_storage_buffer_read_only(3) .build();
428
429 compiler.create_bind_group_layout("Filter Bind Group Layout (with kernel)", &entries)
430 }))
431 }
432
433 fn init_pipeline(
434 device: &GpuDevice,
435 name: &str,
436 entry_point: &str,
437 layout_fn: fn(&GpuDevice) -> Result<&'static BindGroupLayout>,
438 ) -> std::result::Result<ComputePipeline, String> {
439 let compiler = ShaderCompiler::new(device);
440 let shader = compiler
441 .compile(
442 "Filter Shader",
443 ShaderSource::Embedded(crate::shader::embedded::FILTER_SHADER),
444 )
445 .map_err(|e| format!("Failed to compile filter shader: {e}"))?;
446
447 let layout =
448 layout_fn(device).map_err(|e| format!("Failed to create bind group layout: {e}"))?;
449
450 compiler
451 .create_pipeline(name, &shader, entry_point, layout)
452 .map_err(|e| format!("Failed to create pipeline: {e}"))
453 }
454
455 fn get_gaussian_pipeline(device: &GpuDevice) -> Result<&'static ComputePipeline> {
456 static PIPELINE: OnceCell<std::result::Result<ComputePipeline, String>> = OnceCell::new();
457
458 PIPELINE
459 .get_or_init(|| {
460 FilterOperation::init_pipeline(
461 device,
462 "Gaussian Blur Pipeline",
463 "convolve_main",
464 Self::get_bind_group_layout,
465 )
466 })
467 .as_ref()
468 .map_err(|e| crate::GpuError::PipelineCreation(e.clone()))
469 }
470
471 fn get_sharpen_pipeline(device: &GpuDevice) -> Result<&'static ComputePipeline> {
472 static PIPELINE: OnceCell<std::result::Result<ComputePipeline, String>> = OnceCell::new();
473
474 PIPELINE
475 .get_or_init(|| {
476 FilterOperation::init_pipeline(
477 device,
478 "Sharpen Pipeline",
479 "unsharp_mask",
480 Self::get_bind_group_layout,
481 )
482 })
483 .as_ref()
484 .map_err(|e| crate::GpuError::PipelineCreation(e.clone()))
485 }
486
487 fn get_edge_detect_pipeline(device: &GpuDevice) -> Result<&'static ComputePipeline> {
488 static PIPELINE: OnceCell<std::result::Result<ComputePipeline, String>> = OnceCell::new();
489
490 PIPELINE
491 .get_or_init(|| {
492 FilterOperation::init_pipeline(
493 device,
494 "Edge Detect Pipeline",
495 "edge_detect",
496 Self::get_bind_group_layout,
497 )
498 })
499 .as_ref()
500 .map_err(|e| crate::GpuError::PipelineCreation(e.clone()))
501 }
502
503 fn get_convolve_pipeline(device: &GpuDevice) -> Result<&'static ComputePipeline> {
504 static PIPELINE: OnceCell<std::result::Result<ComputePipeline, String>> = OnceCell::new();
505
506 PIPELINE
507 .get_or_init(|| {
508 FilterOperation::init_pipeline(
509 device,
510 "Convolve Pipeline",
511 "convolve_main",
512 Self::get_bind_group_layout_with_kernel,
513 )
514 })
515 .as_ref()
516 .map_err(|e| crate::GpuError::PipelineCreation(e.clone()))
517 }
518}
519
/// Builds a normalised 1-D Gaussian kernel for standard deviation `sigma`.
///
/// The kernel spans `ceil(3 * sigma)` taps on each side of the centre, so it
/// has `2 * ceil(3 * sigma) + 1` entries that sum to 1.
///
/// Degenerate inputs yield the single-tap identity kernel `[1.0]`:
/// * `sigma <= 0.0`,
/// * non-finite `sigma` (NaN or infinity), which would otherwise produce NaN
///   weights or an overflowing kernel length,
/// * a `sigma` so small that `2 * sigma^2` underflows to zero, which would
///   otherwise evaluate `0 / 0` for the centre tap.
///
/// # Panics
///
/// An absurdly large finite `sigma` can still overflow the length computation
/// or exceed available memory when the kernel is allocated.
#[must_use]
pub fn gaussian_kernel_1d(sigma: f32) -> Vec<f32> {
    let two_sigma_sq = 2.0 * sigma * sigma;
    let degenerate = !sigma.is_finite() || sigma <= 0.0 || two_sigma_sq <= 0.0;
    if degenerate {
        return vec![1.0_f32];
    }

    let radius = (3.0 * sigma).ceil() as usize;
    let len = 2 * radius + 1;
    let centre = radius as f32;

    // Unnormalised weights exp(-x^2 / (2 sigma^2)) for x in -radius..=radius.
    let mut kernel: Vec<f32> = (0..len)
        .map(|i| {
            let x = i as f32 - centre;
            (-x * x / two_sigma_sq).exp()
        })
        .collect();

    // The centre weight is exp(0) = 1, so the sum is always at least 1.
    let sum: f32 = kernel.iter().sum();
    for weight in &mut kernel {
        *weight /= sum;
    }
    kernel
}
549
550pub fn gaussian_blur_separable(
563 input: &[u8],
564 output: &mut [u8],
565 width: u32,
566 height: u32,
567 sigma: f32,
568) -> crate::Result<()> {
569 utils::validate_dimensions(width, height)?;
570 utils::validate_buffer_size(input, width, height, 4)?;
571 utils::validate_buffer_size(output, width, height, 4)?;
572
573 let w = width as usize;
574 let h = height as usize;
575 let kernel = gaussian_kernel_1d(sigma);
576 let radius = kernel.len() / 2;
577
578 let mut h_pass = vec![0.0_f32; w * h * 4];
580 for row in 0..h {
581 for col in 0..w {
582 let mut acc = [0.0_f32; 4];
583 let mut wsum = 0.0_f32;
584 for (ki, &kw) in kernel.iter().enumerate() {
585 let sc = col as isize + ki as isize - radius as isize;
586 if sc < 0 || sc >= w as isize {
587 continue;
588 }
589 let src = (row * w + sc as usize) * 4;
590 for c in 0..4 {
591 acc[c] += kw * input[src + c] as f32;
592 }
593 wsum += kw;
594 }
595 let dst = (row * w + col) * 4;
596 let inv = if wsum > 0.0 { 1.0 / wsum } else { 1.0 };
597 for c in 0..4 {
598 h_pass[dst + c] = acc[c] * inv;
599 }
600 }
601 }
602
603 for row in 0..h {
605 for col in 0..w {
606 let mut acc = [0.0_f32; 4];
607 let mut wsum = 0.0_f32;
608 for (ki, &kw) in kernel.iter().enumerate() {
609 let sr = row as isize + ki as isize - radius as isize;
610 if sr < 0 || sr >= h as isize {
611 continue;
612 }
613 let src = (sr as usize * w + col) * 4;
614 for c in 0..4 {
615 acc[c] += kw * h_pass[src + c];
616 }
617 wsum += kw;
618 }
619 let dst = (row * w + col) * 4;
620 let inv = if wsum > 0.0 { 1.0 / wsum } else { 1.0 };
621 for c in 0..4 {
622 output[dst + c] = (acc[c] * inv).round().clamp(0.0, 255.0) as u8;
623 }
624 }
625 }
626
627 Ok(())
628}
629
#[cfg(test)]
mod tests {
    use super::*;

    /// Byte length of a `w` x `h` image at four bytes per pixel.
    fn rgba_len(w: u32, h: u32) -> usize {
        (w * h * 4) as usize
    }

    /// Asserts that `k` is the single-tap identity kernel `[1.0]`.
    fn assert_identity_kernel(k: &[f32]) {
        assert_eq!(k.len(), 1);
        assert!((k[0] - 1.0).abs() < 1e-6);
    }

    // Kernel weights must be normalised.
    #[test]
    fn test_kernel_sums_to_one() {
        let sum = gaussian_kernel_1d(1.0).iter().sum::<f32>();
        assert!((sum - 1.0).abs() < 1e-5, "kernel sum = {sum}");
    }

    // Weights mirrored around the centre must match.
    #[test]
    fn test_kernel_is_symmetric() {
        let k = gaussian_kernel_1d(2.0);
        let mirrored_pairs = k.iter().zip(k.iter().rev()).take(k.len() / 2);
        for (i, (front, back)) in mirrored_pairs.enumerate() {
            assert!(
                (front - back).abs() < 1e-6,
                "asymmetric at index {i}: {} vs {}",
                front,
                back
            );
        }
    }

    // No weight may exceed the centre weight.
    #[test]
    fn test_kernel_center_is_largest() {
        let k = gaussian_kernel_1d(1.5);
        let center = k[k.len() / 2];
        for v in k.iter().copied() {
            assert!(center >= v, "center {center} not >= {v}");
        }
    }

    #[test]
    fn test_kernel_zero_sigma_returns_identity() {
        assert_identity_kernel(&gaussian_kernel_1d(0.0));
    }

    #[test]
    fn test_kernel_negative_sigma_returns_identity() {
        assert_identity_kernel(&gaussian_kernel_1d(-1.0));
    }

    // A constant image must survive the blur up to one unit of rounding error.
    #[test]
    fn test_blur_uniform_image_unchanged() {
        let (w, h) = (8u32, 8u32);
        let input: Vec<u8> = [128u8, 128, 128, 255]
            .iter()
            .copied()
            .cycle()
            .take(rgba_len(w, h))
            .collect();
        let mut output = vec![0u8; rgba_len(w, h)];

        gaussian_blur_separable(&input, &mut output, w, h, 1.5).expect("blur should succeed");

        for (i, (&inp, &out)) in input.iter().zip(&output).enumerate() {
            let diff = (i32::from(inp) - i32::from(out)).unsigned_abs();
            assert!(diff <= 1, "pixel {i}: input={inp} output={out}");
        }
    }

    // Blurring a black/white checkerboard must pull the peak below pure white.
    #[test]
    fn test_blur_reduces_contrast() {
        let (w, h) = (4u32, 4u32);
        let side = w as usize;
        let input: Vec<u8> = (0..h as usize * side)
            .flat_map(|px| {
                let (row, col) = (px / side, px % side);
                let v = if (row + col) % 2 == 0 { 255u8 } else { 0u8 };
                [v, v, v, 255]
            })
            .collect();
        let mut output = vec![0u8; rgba_len(w, h)];

        gaussian_blur_separable(&input, &mut output, w, h, 1.0).expect("blur should succeed");

        let max_rgb = output
            .chunks(4)
            .flat_map(|px| px[..3].iter().copied())
            .max()
            .unwrap_or(0);
        assert!(
            max_rgb < 255,
            "max_rgb after blur = {max_rgb}; expected < 255"
        );
    }

    // An output buffer of the wrong size must be rejected, not written to.
    #[test]
    fn test_blur_size_mismatch_returns_error() {
        let (w, h) = (4u32, 4u32);
        let input = vec![0u8; rgba_len(w, h)];
        let mut output = vec![0u8; 10];

        let result = gaussian_blur_separable(&input, &mut output, w, h, 1.0);

        assert!(result.is_err());
    }

    // A 1x1 image has no neighbours, so the pixel passes through unchanged.
    #[test]
    fn test_blur_single_pixel_passthrough() {
        let input = [100u8, 150, 200, 255];
        let mut output = [0u8; 4];

        gaussian_blur_separable(&input, &mut output, 1, 1, 1.0).expect("blur should succeed");

        assert_eq!(output, input);
    }
}