// image_convolution/convolution.rs
use crate::gpu_device::*;
use crate::{Image, Kernel, Real};

4pub async fn run(device: &GpuDevice, image: &Image, kernel: &Kernel) -> Image {
5 let crop = kernel.size - 1;
6 let mut output = Image {
7 data: Vec::new(),
8 width: image.width - crop,
9 height: image.height - crop,
10 };
11 let output_size = (output.size() * std::mem::size_of::<Real>() as u32) as u64;
12 let params = [image.width, kernel.size];
13 let params_data = bytemuck::cast_slice(¶ms);
14
15 let input_buffer = device.create_data_buffer("input", bytemuck::cast_slice(&image.data));
17 let result_buffer = device.create_buffer("result", output_size);
18 let kernel_buffer = device.create_data_buffer("kernel", bytemuck::cast_slice(&kernel.data));
19 let params_buffer = device.create_uniform_buffer("params", params_data);
20 let output_buffer = device.create_output_buffer("output", output_size);
21
22 let (bind_group, compute_pipeline) = device.create_compute_pipeline(
24 &[
25 (
26 &input_buffer,
27 4,
28 wgpu::BufferBindingType::Storage { read_only: true },
29 ),
30 (
31 &result_buffer,
32 4,
33 wgpu::BufferBindingType::Storage { read_only: false },
34 ),
35 (
36 &kernel_buffer,
37 4,
38 wgpu::BufferBindingType::Storage { read_only: true },
39 ),
40 (
41 ¶ms_buffer,
42 params_data.len() as u64,
43 wgpu::BufferBindingType::Uniform,
44 ),
45 ],
46 include_str!("convolution.wgsl"),
47 );
48
49 let mut encoder = device
51 .device
52 .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
53
54 {
55 let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { label: None });
56 cpass.set_bind_group(0, &bind_group, &[]);
57 cpass.set_pipeline(&compute_pipeline);
58 cpass.dispatch(output.width, output.height, 1);
59 }
60 encoder.copy_buffer_to_buffer(&result_buffer, 0, &output_buffer, 0, output_size);
62 device.queue.submit(Some(encoder.finish()));
63
64 let buffer_slice = output_buffer.slice(..);
66 let buffer_future = buffer_slice.map_async(wgpu::MapMode::Read);
67 device.device.poll(wgpu::Maintain::Wait);
68
69 if let Ok(()) = buffer_future.await {
71 let data = buffer_slice.get_mapped_range();
72 output.data = bytemuck::cast_slice::<u8, f32>(&data).to_vec();
73
74 drop(data);
76 output_buffer.unmap();
77
78 output
79 } else {
80 panic!("failed to run compute on gpu!")
81 }
82}