Skip to main content

trueno/backends/gpu/device/linalg/
convolve2d.rs

1//! GPU 2D convolution operations
2
3use super::super::GpuDevice;
4#[cfg(any(feature = "gpu", feature = "gpu-wasm"))]
5use crate::backends::gpu::runtime;
6use crate::backends::gpu::shaders;
7
8impl GpuDevice {
9    /// Perform 2D convolution on GPU (sync, native only)
10    ///
11    /// # Arguments
12    ///
13    /// * `input` - Input image (row-major)
14    /// * `kernel` - Convolution kernel (row-major)
15    /// * `result` - Output buffer (row-major)
16    /// * `input_rows` - Number of rows in input
17    /// * `input_cols` - Number of columns in input
18    /// * `kernel_rows` - Number of rows in kernel
19    /// * `kernel_cols` - Number of columns in kernel
20    ///
21    /// Output dimensions: (input_rows - kernel_rows + 1) x (input_cols - kernel_cols + 1)
22    #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
23    #[allow(clippy::too_many_arguments)]
24    pub fn convolve2d(
25        &self,
26        input: &[f32],
27        kernel: &[f32],
28        result: &mut [f32],
29        input_rows: usize,
30        input_cols: usize,
31        kernel_rows: usize,
32        kernel_cols: usize,
33    ) -> Result<(), String> {
34        runtime::block_on(async {
35            self.convolve2d_async(
36                input,
37                kernel,
38                result,
39                input_rows,
40                input_cols,
41                kernel_rows,
42                kernel_cols,
43            )
44            .await
45        })
46    }
47
48    /// Perform 2D convolution on GPU (async, works on all platforms)
49    #[allow(clippy::too_many_arguments)]
50    pub async fn convolve2d_async(
51        &self,
52        input: &[f32],
53        kernel: &[f32],
54        result: &mut [f32],
55        input_rows: usize,
56        input_cols: usize,
57        kernel_rows: usize,
58        kernel_cols: usize,
59    ) -> Result<(), String> {
60        if kernel_rows > input_rows || kernel_cols > input_cols {
61            return Err(format!(
62                "Kernel size ({}x{}) larger than input ({}x{})",
63                kernel_rows, kernel_cols, input_rows, input_cols
64            ));
65        }
66        let output_rows = input_rows - kernel_rows + 1;
67        let output_cols = input_cols - kernel_cols + 1;
68
69        // Create shader module
70        let shader = self.device.create_shader_module(wgpu::ShaderModuleDescriptor {
71            label: Some("Convolve2D Shader"),
72            source: wgpu::ShaderSource::Wgsl(shaders::CONVOLVE2D_SHADER.into()),
73        });
74
75        // Create buffers
76        let input_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
77            label: Some("Input Image"),
78            size: std::mem::size_of_val(input) as u64,
79            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
80            mapped_at_creation: false,
81        });
82
83        let kernel_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
84            label: Some("Kernel"),
85            size: std::mem::size_of_val(kernel) as u64,
86            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
87            mapped_at_creation: false,
88        });
89
90        let output_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
91            label: Some("Output"),
92            size: std::mem::size_of_val(result) as u64,
93            usage: wgpu::BufferUsages::STORAGE
94                | wgpu::BufferUsages::COPY_SRC
95                | wgpu::BufferUsages::COPY_DST,
96            mapped_at_creation: false,
97        });
98
99        // Dimensions uniform buffer
100        #[repr(C)]
101        #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
102        struct ConvDimensions {
103            input_rows: u32,
104            input_cols: u32,
105            kernel_rows: u32,
106            kernel_cols: u32,
107            output_rows: u32,
108            output_cols: u32,
109        }
110
111        let dims = ConvDimensions {
112            input_rows: input_rows as u32,
113            input_cols: input_cols as u32,
114            kernel_rows: kernel_rows as u32,
115            kernel_cols: kernel_cols as u32,
116            output_rows: output_rows as u32,
117            output_cols: output_cols as u32,
118        };
119
120        let dims_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
121            label: Some("Conv Dimensions"),
122            size: std::mem::size_of::<ConvDimensions>() as u64,
123            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
124            mapped_at_creation: false,
125        });
126
127        // Write data to buffers
128        self.queue.write_buffer(&input_buffer, 0, bytemuck::cast_slice(input));
129        self.queue.write_buffer(&kernel_buffer, 0, bytemuck::cast_slice(kernel));
130        self.queue.write_buffer(&dims_buffer, 0, bytemuck::bytes_of(&dims));
131
132        // Create bind group layout
133        let bind_group_layout =
134            self.device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
135                label: Some("Convolve2D Bind Group Layout"),
136                entries: &[
137                    wgpu::BindGroupLayoutEntry {
138                        binding: 0,
139                        visibility: wgpu::ShaderStages::COMPUTE,
140                        ty: wgpu::BindingType::Buffer {
141                            ty: wgpu::BufferBindingType::Storage { read_only: true },
142                            has_dynamic_offset: false,
143                            min_binding_size: None,
144                        },
145                        count: None,
146                    },
147                    wgpu::BindGroupLayoutEntry {
148                        binding: 1,
149                        visibility: wgpu::ShaderStages::COMPUTE,
150                        ty: wgpu::BindingType::Buffer {
151                            ty: wgpu::BufferBindingType::Storage { read_only: true },
152                            has_dynamic_offset: false,
153                            min_binding_size: None,
154                        },
155                        count: None,
156                    },
157                    wgpu::BindGroupLayoutEntry {
158                        binding: 2,
159                        visibility: wgpu::ShaderStages::COMPUTE,
160                        ty: wgpu::BindingType::Buffer {
161                            ty: wgpu::BufferBindingType::Storage { read_only: false },
162                            has_dynamic_offset: false,
163                            min_binding_size: None,
164                        },
165                        count: None,
166                    },
167                    wgpu::BindGroupLayoutEntry {
168                        binding: 3,
169                        visibility: wgpu::ShaderStages::COMPUTE,
170                        ty: wgpu::BindingType::Buffer {
171                            ty: wgpu::BufferBindingType::Uniform,
172                            has_dynamic_offset: false,
173                            min_binding_size: None,
174                        },
175                        count: None,
176                    },
177                ],
178            });
179
180        // Create bind group
181        let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
182            label: Some("Convolve2D Bind Group"),
183            layout: &bind_group_layout,
184            entries: &[
185                wgpu::BindGroupEntry { binding: 0, resource: input_buffer.as_entire_binding() },
186                wgpu::BindGroupEntry { binding: 1, resource: kernel_buffer.as_entire_binding() },
187                wgpu::BindGroupEntry { binding: 2, resource: output_buffer.as_entire_binding() },
188                wgpu::BindGroupEntry { binding: 3, resource: dims_buffer.as_entire_binding() },
189            ],
190        });
191
192        // Create pipeline layout
193        let pipeline_layout = self.device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
194            label: Some("Convolve2D Pipeline Layout"),
195            bind_group_layouts: &[&bind_group_layout],
196            push_constant_ranges: &[],
197        });
198
199        // Create compute pipeline
200        let pipeline = self.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
201            label: Some("Convolve2D Pipeline"),
202            layout: Some(&pipeline_layout),
203            module: &shader,
204            entry_point: Some("main"),
205            compilation_options: Default::default(),
206            cache: None,
207        });
208
209        // Create command encoder
210        let mut encoder = self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
211            label: Some("Convolve2D Encoder"),
212        });
213
214        // Compute pass
215        {
216            let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
217                label: Some("Convolve2D Pass"),
218                timestamp_writes: None,
219            });
220
221            compute_pass.set_pipeline(&pipeline);
222            compute_pass.set_bind_group(0, &bind_group, &[]);
223
224            // Dispatch workgroups: 16x16 threads per workgroup
225            let workgroup_size_x = 16;
226            let workgroup_size_y = 16;
227            let num_workgroups_x = (output_rows as u32).div_ceil(workgroup_size_x);
228            let num_workgroups_y = (output_cols as u32).div_ceil(workgroup_size_y);
229            compute_pass.dispatch_workgroups(num_workgroups_x, num_workgroups_y, 1);
230        }
231
232        // Create staging buffer for result readback
233        let staging_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
234            label: Some("Staging Buffer"),
235            size: std::mem::size_of_val(result) as u64,
236            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
237            mapped_at_creation: false,
238        });
239
240        // Copy output to staging buffer
241        encoder.copy_buffer_to_buffer(
242            &output_buffer,
243            0,
244            &staging_buffer,
245            0,
246            std::mem::size_of_val(result) as u64,
247        );
248
249        // Submit commands
250        self.queue.submit(Some(encoder.finish()));
251
252        // Read result back
253        let buffer_slice = staging_buffer.slice(..);
254        let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel();
255        buffer_slice.map_async(wgpu::MapMode::Read, move |result| {
256            sender.send(result).expect("oneshot channel receiver dropped");
257        });
258
259        // Poll device to ensure GPU work completes and callbacks are invoked
260        self.device.poll(wgpu::PollType::Wait { submission_index: None, timeout: None }).ok();
261
262        receiver
263            .receive()
264            .await
265            .ok_or("Failed to receive mapping result")?
266            .map_err(|e| format!("Buffer mapping failed: {:?}", e))?;
267
268        {
269            let data = buffer_slice.get_mapped_range();
270            let output_data: &[f32] = bytemuck::cast_slice(&data);
271            result.copy_from_slice(output_data);
272        }
273
274        staging_buffer.unmap();
275
276        Ok(())
277    }
278}