// trueno/backends/gpu/device/linalg/convolve2d.rs
use super::super::GpuDevice;
4#[cfg(any(feature = "gpu", feature = "gpu-wasm"))]
5use crate::backends::gpu::runtime;
6use crate::backends::gpu::shaders;
7
8impl GpuDevice {
9 #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
23 #[allow(clippy::too_many_arguments)]
24 pub fn convolve2d(
25 &self,
26 input: &[f32],
27 kernel: &[f32],
28 result: &mut [f32],
29 input_rows: usize,
30 input_cols: usize,
31 kernel_rows: usize,
32 kernel_cols: usize,
33 ) -> Result<(), String> {
34 runtime::block_on(async {
35 self.convolve2d_async(
36 input,
37 kernel,
38 result,
39 input_rows,
40 input_cols,
41 kernel_rows,
42 kernel_cols,
43 )
44 .await
45 })
46 }
47
48 #[allow(clippy::too_many_arguments)]
50 pub async fn convolve2d_async(
51 &self,
52 input: &[f32],
53 kernel: &[f32],
54 result: &mut [f32],
55 input_rows: usize,
56 input_cols: usize,
57 kernel_rows: usize,
58 kernel_cols: usize,
59 ) -> Result<(), String> {
60 if kernel_rows > input_rows || kernel_cols > input_cols {
61 return Err(format!(
62 "Kernel size ({}x{}) larger than input ({}x{})",
63 kernel_rows, kernel_cols, input_rows, input_cols
64 ));
65 }
66 let output_rows = input_rows - kernel_rows + 1;
67 let output_cols = input_cols - kernel_cols + 1;
68
69 let shader = self.device.create_shader_module(wgpu::ShaderModuleDescriptor {
71 label: Some("Convolve2D Shader"),
72 source: wgpu::ShaderSource::Wgsl(shaders::CONVOLVE2D_SHADER.into()),
73 });
74
75 let input_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
77 label: Some("Input Image"),
78 size: std::mem::size_of_val(input) as u64,
79 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
80 mapped_at_creation: false,
81 });
82
83 let kernel_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
84 label: Some("Kernel"),
85 size: std::mem::size_of_val(kernel) as u64,
86 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
87 mapped_at_creation: false,
88 });
89
90 let output_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
91 label: Some("Output"),
92 size: std::mem::size_of_val(result) as u64,
93 usage: wgpu::BufferUsages::STORAGE
94 | wgpu::BufferUsages::COPY_SRC
95 | wgpu::BufferUsages::COPY_DST,
96 mapped_at_creation: false,
97 });
98
99 #[repr(C)]
101 #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
102 struct ConvDimensions {
103 input_rows: u32,
104 input_cols: u32,
105 kernel_rows: u32,
106 kernel_cols: u32,
107 output_rows: u32,
108 output_cols: u32,
109 }
110
111 let dims = ConvDimensions {
112 input_rows: input_rows as u32,
113 input_cols: input_cols as u32,
114 kernel_rows: kernel_rows as u32,
115 kernel_cols: kernel_cols as u32,
116 output_rows: output_rows as u32,
117 output_cols: output_cols as u32,
118 };
119
120 let dims_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
121 label: Some("Conv Dimensions"),
122 size: std::mem::size_of::<ConvDimensions>() as u64,
123 usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
124 mapped_at_creation: false,
125 });
126
127 self.queue.write_buffer(&input_buffer, 0, bytemuck::cast_slice(input));
129 self.queue.write_buffer(&kernel_buffer, 0, bytemuck::cast_slice(kernel));
130 self.queue.write_buffer(&dims_buffer, 0, bytemuck::bytes_of(&dims));
131
132 let bind_group_layout =
134 self.device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
135 label: Some("Convolve2D Bind Group Layout"),
136 entries: &[
137 wgpu::BindGroupLayoutEntry {
138 binding: 0,
139 visibility: wgpu::ShaderStages::COMPUTE,
140 ty: wgpu::BindingType::Buffer {
141 ty: wgpu::BufferBindingType::Storage { read_only: true },
142 has_dynamic_offset: false,
143 min_binding_size: None,
144 },
145 count: None,
146 },
147 wgpu::BindGroupLayoutEntry {
148 binding: 1,
149 visibility: wgpu::ShaderStages::COMPUTE,
150 ty: wgpu::BindingType::Buffer {
151 ty: wgpu::BufferBindingType::Storage { read_only: true },
152 has_dynamic_offset: false,
153 min_binding_size: None,
154 },
155 count: None,
156 },
157 wgpu::BindGroupLayoutEntry {
158 binding: 2,
159 visibility: wgpu::ShaderStages::COMPUTE,
160 ty: wgpu::BindingType::Buffer {
161 ty: wgpu::BufferBindingType::Storage { read_only: false },
162 has_dynamic_offset: false,
163 min_binding_size: None,
164 },
165 count: None,
166 },
167 wgpu::BindGroupLayoutEntry {
168 binding: 3,
169 visibility: wgpu::ShaderStages::COMPUTE,
170 ty: wgpu::BindingType::Buffer {
171 ty: wgpu::BufferBindingType::Uniform,
172 has_dynamic_offset: false,
173 min_binding_size: None,
174 },
175 count: None,
176 },
177 ],
178 });
179
180 let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
182 label: Some("Convolve2D Bind Group"),
183 layout: &bind_group_layout,
184 entries: &[
185 wgpu::BindGroupEntry { binding: 0, resource: input_buffer.as_entire_binding() },
186 wgpu::BindGroupEntry { binding: 1, resource: kernel_buffer.as_entire_binding() },
187 wgpu::BindGroupEntry { binding: 2, resource: output_buffer.as_entire_binding() },
188 wgpu::BindGroupEntry { binding: 3, resource: dims_buffer.as_entire_binding() },
189 ],
190 });
191
192 let pipeline_layout = self.device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
194 label: Some("Convolve2D Pipeline Layout"),
195 bind_group_layouts: &[&bind_group_layout],
196 push_constant_ranges: &[],
197 });
198
199 let pipeline = self.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
201 label: Some("Convolve2D Pipeline"),
202 layout: Some(&pipeline_layout),
203 module: &shader,
204 entry_point: Some("main"),
205 compilation_options: Default::default(),
206 cache: None,
207 });
208
209 let mut encoder = self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
211 label: Some("Convolve2D Encoder"),
212 });
213
214 {
216 let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
217 label: Some("Convolve2D Pass"),
218 timestamp_writes: None,
219 });
220
221 compute_pass.set_pipeline(&pipeline);
222 compute_pass.set_bind_group(0, &bind_group, &[]);
223
224 let workgroup_size_x = 16;
226 let workgroup_size_y = 16;
227 let num_workgroups_x = (output_rows as u32).div_ceil(workgroup_size_x);
228 let num_workgroups_y = (output_cols as u32).div_ceil(workgroup_size_y);
229 compute_pass.dispatch_workgroups(num_workgroups_x, num_workgroups_y, 1);
230 }
231
232 let staging_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
234 label: Some("Staging Buffer"),
235 size: std::mem::size_of_val(result) as u64,
236 usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
237 mapped_at_creation: false,
238 });
239
240 encoder.copy_buffer_to_buffer(
242 &output_buffer,
243 0,
244 &staging_buffer,
245 0,
246 std::mem::size_of_val(result) as u64,
247 );
248
249 self.queue.submit(Some(encoder.finish()));
251
252 let buffer_slice = staging_buffer.slice(..);
254 let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel();
255 buffer_slice.map_async(wgpu::MapMode::Read, move |result| {
256 sender.send(result).expect("oneshot channel receiver dropped");
257 });
258
259 self.device.poll(wgpu::PollType::Wait { submission_index: None, timeout: None }).ok();
261
262 receiver
263 .receive()
264 .await
265 .ok_or("Failed to receive mapping result")?
266 .map_err(|e| format!("Buffer mapping failed: {:?}", e))?;
267
268 {
269 let data = buffer_slice.get_mapped_range();
270 let output_data: &[f32] = bytemuck::cast_slice(&data);
271 result.copy_from_slice(output_data);
272 }
273
274 staging_buffer.unmap();
275
276 Ok(())
277 }
278}