Skip to main content

trueno/backends/gpu/device/reductions/
tiled_2d.rs

//! Generic 2D tiled reduction helper
//!
//! Shared implementation for tiled sum/max/min reductions on GPU.

use super::super::GpuDevice;

7impl GpuDevice {
8    /// Generic 2D tiled reduction helper
9    #[allow(clippy::too_many_arguments)]
10    pub(super) async fn tiled_reduce_2d_async<F>(
11        &self,
12        data: &[f32],
13        width: usize,
14        height: usize,
15        shader_source: &str,
16        op_name: &str,
17        identity: f32,
18        combine: F,
19    ) -> Result<f32, String>
20    where
21        F: Fn(&[f32]) -> f32,
22    {
23        if data.is_empty() || width == 0 || height == 0 {
24            return Ok(identity);
25        }
26
27        // Calculate workgroup dimensions (16x16 tiles)
28        let workgroup_size_x: u32 = 16;
29        let workgroup_size_y: u32 = 16;
30        let num_workgroups_x = (width as u32).div_ceil(workgroup_size_x);
31        let num_workgroups_y = (height as u32).div_ceil(workgroup_size_y);
32        let total_workgroups = (num_workgroups_x * num_workgroups_y) as usize;
33
34        // Create shader module
35        let shader = self.device.create_shader_module(wgpu::ShaderModuleDescriptor {
36            label: Some(&format!("{} Shader", op_name)),
37            source: wgpu::ShaderSource::Wgsl(shader_source.into()),
38        });
39
40        // Create input buffer
41        let input_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
42            label: Some(&format!("{} Input", op_name)),
43            size: std::mem::size_of_val(data) as u64,
44            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
45            mapped_at_creation: false,
46        });
47
48        // Create partial results buffer
49        let partial_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
50            label: Some(&format!("{} Partial Results", op_name)),
51            size: (total_workgroups * std::mem::size_of::<f32>()) as u64,
52            usage: wgpu::BufferUsages::STORAGE
53                | wgpu::BufferUsages::COPY_SRC
54                | wgpu::BufferUsages::COPY_DST,
55            mapped_at_creation: false,
56        });
57
58        // Dimensions uniform buffer
59        #[repr(C)]
60        #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
61        struct Dimensions {
62            width: u32,
63            height: u32,
64        }
65
66        let dims = Dimensions { width: width as u32, height: height as u32 };
67
68        let dims_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
69            label: Some(&format!("{} Dimensions", op_name)),
70            size: std::mem::size_of::<Dimensions>() as u64,
71            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
72            mapped_at_creation: false,
73        });
74
75        // Write data
76        self.queue.write_buffer(&input_buffer, 0, bytemuck::cast_slice(data));
77        self.queue.write_buffer(&dims_buffer, 0, bytemuck::bytes_of(&dims));
78
79        // Create bind group layout
80        let bind_group_layout =
81            self.device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
82                label: Some(&format!("{} Bind Group Layout", op_name)),
83                entries: &[
84                    wgpu::BindGroupLayoutEntry {
85                        binding: 0,
86                        visibility: wgpu::ShaderStages::COMPUTE,
87                        ty: wgpu::BindingType::Buffer {
88                            ty: wgpu::BufferBindingType::Storage { read_only: true },
89                            has_dynamic_offset: false,
90                            min_binding_size: None,
91                        },
92                        count: None,
93                    },
94                    wgpu::BindGroupLayoutEntry {
95                        binding: 1,
96                        visibility: wgpu::ShaderStages::COMPUTE,
97                        ty: wgpu::BindingType::Buffer {
98                            ty: wgpu::BufferBindingType::Storage { read_only: false },
99                            has_dynamic_offset: false,
100                            min_binding_size: None,
101                        },
102                        count: None,
103                    },
104                    wgpu::BindGroupLayoutEntry {
105                        binding: 2,
106                        visibility: wgpu::ShaderStages::COMPUTE,
107                        ty: wgpu::BindingType::Buffer {
108                            ty: wgpu::BufferBindingType::Uniform,
109                            has_dynamic_offset: false,
110                            min_binding_size: None,
111                        },
112                        count: None,
113                    },
114                ],
115            });
116
117        // Create bind group
118        let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
119            label: Some(&format!("{} Bind Group", op_name)),
120            layout: &bind_group_layout,
121            entries: &[
122                wgpu::BindGroupEntry { binding: 0, resource: input_buffer.as_entire_binding() },
123                wgpu::BindGroupEntry { binding: 1, resource: partial_buffer.as_entire_binding() },
124                wgpu::BindGroupEntry { binding: 2, resource: dims_buffer.as_entire_binding() },
125            ],
126        });
127
128        // Create pipeline
129        let pipeline_layout = self.device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
130            label: Some(&format!("{} Pipeline Layout", op_name)),
131            bind_group_layouts: &[&bind_group_layout],
132            push_constant_ranges: &[],
133        });
134
135        let pipeline = self.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
136            label: Some(&format!("{} Pipeline", op_name)),
137            layout: Some(&pipeline_layout),
138            module: &shader,
139            entry_point: Some("main"),
140            compilation_options: Default::default(),
141            cache: None,
142        });
143
144        // Create staging buffer
145        let staging_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
146            label: Some(&format!("{} Staging", op_name)),
147            size: (total_workgroups * std::mem::size_of::<f32>()) as u64,
148            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
149            mapped_at_creation: false,
150        });
151
152        // Create command encoder
153        let mut encoder = self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
154            label: Some(&format!("{} Encoder", op_name)),
155        });
156
157        {
158            let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
159                label: Some(&format!("{} Pass", op_name)),
160                timestamp_writes: None,
161            });
162            compute_pass.set_pipeline(&pipeline);
163            compute_pass.set_bind_group(0, &bind_group, &[]);
164            compute_pass.dispatch_workgroups(num_workgroups_x, num_workgroups_y, 1);
165        }
166
167        // Copy result to staging buffer
168        encoder.copy_buffer_to_buffer(
169            &partial_buffer,
170            0,
171            &staging_buffer,
172            0,
173            (total_workgroups * std::mem::size_of::<f32>()) as u64,
174        );
175
176        // Submit commands
177        self.queue.submit(Some(encoder.finish()));
178
179        // Read back results
180        let buffer_slice = staging_buffer.slice(..);
181        let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel();
182        buffer_slice.map_async(wgpu::MapMode::Read, move |result| {
183            sender.send(result).ok();
184        });
185
186        // Poll device
187        self.device.poll(wgpu::PollType::Wait { submission_index: None, timeout: None }).ok();
188
189        receiver
190            .receive()
191            .await
192            .ok_or("Failed to receive mapping result")?
193            .map_err(|e| format!("Buffer mapping failed: {:?}", e))?;
194
195        let final_result = {
196            let data = buffer_slice.get_mapped_range();
197            let partials: &[f32] = bytemuck::cast_slice(&data);
198            combine(partials)
199        };
200
201        staging_buffer.unmap();
202
203        Ok(final_result)
204    }
205}