// trueno/backends/gpu/device/reductions/reduce_1d.rs
1//! 1D parallel reduction operations (max, sum)
2//!
3//! These are internal helpers used by activation functions (softmax, log_softmax).
4
5use super::super::super::shaders;
6use super::super::GpuDevice;
7
8impl GpuDevice {
9    /// Helper: Parallel max reduction
10    pub(in crate::backends::gpu::device) async fn reduce_max(
11        &self,
12        input: &[f32],
13    ) -> Result<f32, String> {
14        let len = input.len();
15        let workgroup_size = 256;
16        let num_workgroups = (len as u32).div_ceil(workgroup_size);
17
18        // Create shader module
19        let shader = self.device.create_shader_module(wgpu::ShaderModuleDescriptor {
20            label: Some("Max Reduction Shader"),
21            source: wgpu::ShaderSource::Wgsl(shaders::MAX_REDUCTION_SHADER.into()),
22        });
23
24        // Create input buffer
25        let input_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
26            label: Some("Max Reduction Input"),
27            size: std::mem::size_of_val(input) as u64,
28            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
29            mapped_at_creation: false,
30        });
31
32        // Result buffer for partial maxes
33        let partial_results = vec![f32::NEG_INFINITY; num_workgroups as usize];
34        let result_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
35            label: Some("Max Partial Results"),
36            size: std::mem::size_of_val(partial_results.as_slice()) as u64,
37            usage: wgpu::BufferUsages::STORAGE
38                | wgpu::BufferUsages::COPY_SRC
39                | wgpu::BufferUsages::COPY_DST,
40            mapped_at_creation: false,
41        });
42
43        self.queue.write_buffer(&input_buffer, 0, bytemuck::cast_slice(input));
44
45        // Create bind group layout
46        let bind_group_layout =
47            self.device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
48                label: Some("Max Reduction Bind Group Layout"),
49                entries: &[
50                    wgpu::BindGroupLayoutEntry {
51                        binding: 0,
52                        visibility: wgpu::ShaderStages::COMPUTE,
53                        ty: wgpu::BindingType::Buffer {
54                            ty: wgpu::BufferBindingType::Storage { read_only: true },
55                            has_dynamic_offset: false,
56                            min_binding_size: None,
57                        },
58                        count: None,
59                    },
60                    wgpu::BindGroupLayoutEntry {
61                        binding: 1,
62                        visibility: wgpu::ShaderStages::COMPUTE,
63                        ty: wgpu::BindingType::Buffer {
64                            ty: wgpu::BufferBindingType::Storage { read_only: false },
65                            has_dynamic_offset: false,
66                            min_binding_size: None,
67                        },
68                        count: None,
69                    },
70                ],
71            });
72
73        let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
74            label: Some("Max Reduction Bind Group"),
75            layout: &bind_group_layout,
76            entries: &[
77                wgpu::BindGroupEntry { binding: 0, resource: input_buffer.as_entire_binding() },
78                wgpu::BindGroupEntry { binding: 1, resource: result_buffer.as_entire_binding() },
79            ],
80        });
81
82        let pipeline_layout = self.device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
83            label: Some("Max Reduction Pipeline Layout"),
84            bind_group_layouts: &[&bind_group_layout],
85            push_constant_ranges: &[],
86        });
87
88        let pipeline = self.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
89            label: Some("Max Reduction Pipeline"),
90            layout: Some(&pipeline_layout),
91            module: &shader,
92            entry_point: Some("main"),
93            compilation_options: Default::default(),
94            cache: None,
95        });
96
97        let mut encoder = self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
98            label: Some("Max Reduction Encoder"),
99        });
100
101        {
102            let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
103                label: Some("Max Reduction Pass"),
104                timestamp_writes: None,
105            });
106
107            compute_pass.set_pipeline(&pipeline);
108            compute_pass.set_bind_group(0, &bind_group, &[]);
109            compute_pass.dispatch_workgroups(num_workgroups, 1, 1);
110        }
111
112        // Create staging buffer
113        let staging_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
114            label: Some("Max Staging Buffer"),
115            size: std::mem::size_of_val(partial_results.as_slice()) as u64,
116            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
117            mapped_at_creation: false,
118        });
119
120        encoder.copy_buffer_to_buffer(
121            &result_buffer,
122            0,
123            &staging_buffer,
124            0,
125            std::mem::size_of_val(partial_results.as_slice()) as u64,
126        );
127
128        self.queue.submit(Some(encoder.finish()));
129
130        let buffer_slice = staging_buffer.slice(..);
131        let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel();
132        buffer_slice.map_async(wgpu::MapMode::Read, move |result| {
133            sender.send(result).ok();
134        });
135
136        // Poll device to ensure GPU work completes and callbacks are invoked
137        self.device.poll(wgpu::PollType::Wait { submission_index: None, timeout: None }).ok();
138        receiver
139            .receive()
140            .await
141            .ok_or("Channel receive failed")?
142            .map_err(|e| format!("Buffer map failed: {:?}", e))?;
143
144        let data = buffer_slice.get_mapped_range();
145        let result: Vec<f32> = bytemuck::cast_slice(&data).to_vec();
146        drop(data);
147        staging_buffer.unmap();
148
149        // Final reduction on CPU
150        Ok(result.iter().copied().fold(f32::NEG_INFINITY, f32::max))
151    }
152
153    /// Helper: Parallel sum reduction
154    pub(in crate::backends::gpu::device) async fn reduce_sum(
155        &self,
156        input: &[f32],
157    ) -> Result<f32, String> {
158        let len = input.len();
159        let workgroup_size = 256;
160        let num_workgroups = (len as u32).div_ceil(workgroup_size);
161
162        let shader = self.device.create_shader_module(wgpu::ShaderModuleDescriptor {
163            label: Some("Sum Reduction Shader"),
164            source: wgpu::ShaderSource::Wgsl(shaders::SUM_REDUCTION_SHADER.into()),
165        });
166
167        let input_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
168            label: Some("Sum Reduction Input"),
169            size: std::mem::size_of_val(input) as u64,
170            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
171            mapped_at_creation: false,
172        });
173
174        let partial_results = vec![0.0f32; num_workgroups as usize];
175        let result_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
176            label: Some("Sum Partial Results"),
177            size: std::mem::size_of_val(partial_results.as_slice()) as u64,
178            usage: wgpu::BufferUsages::STORAGE
179                | wgpu::BufferUsages::COPY_SRC
180                | wgpu::BufferUsages::COPY_DST,
181            mapped_at_creation: false,
182        });
183
184        self.queue.write_buffer(&input_buffer, 0, bytemuck::cast_slice(input));
185
186        let bind_group_layout =
187            self.device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
188                label: Some("Sum Reduction Bind Group Layout"),
189                entries: &[
190                    wgpu::BindGroupLayoutEntry {
191                        binding: 0,
192                        visibility: wgpu::ShaderStages::COMPUTE,
193                        ty: wgpu::BindingType::Buffer {
194                            ty: wgpu::BufferBindingType::Storage { read_only: true },
195                            has_dynamic_offset: false,
196                            min_binding_size: None,
197                        },
198                        count: None,
199                    },
200                    wgpu::BindGroupLayoutEntry {
201                        binding: 1,
202                        visibility: wgpu::ShaderStages::COMPUTE,
203                        ty: wgpu::BindingType::Buffer {
204                            ty: wgpu::BufferBindingType::Storage { read_only: false },
205                            has_dynamic_offset: false,
206                            min_binding_size: None,
207                        },
208                        count: None,
209                    },
210                ],
211            });
212
213        let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
214            label: Some("Sum Reduction Bind Group"),
215            layout: &bind_group_layout,
216            entries: &[
217                wgpu::BindGroupEntry { binding: 0, resource: input_buffer.as_entire_binding() },
218                wgpu::BindGroupEntry { binding: 1, resource: result_buffer.as_entire_binding() },
219            ],
220        });
221
222        let pipeline_layout = self.device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
223            label: Some("Sum Reduction Pipeline Layout"),
224            bind_group_layouts: &[&bind_group_layout],
225            push_constant_ranges: &[],
226        });
227
228        let pipeline = self.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
229            label: Some("Sum Reduction Pipeline"),
230            layout: Some(&pipeline_layout),
231            module: &shader,
232            entry_point: Some("main"),
233            compilation_options: Default::default(),
234            cache: None,
235        });
236
237        let mut encoder = self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
238            label: Some("Sum Reduction Encoder"),
239        });
240
241        {
242            let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
243                label: Some("Sum Reduction Pass"),
244                timestamp_writes: None,
245            });
246
247            compute_pass.set_pipeline(&pipeline);
248            compute_pass.set_bind_group(0, &bind_group, &[]);
249            compute_pass.dispatch_workgroups(num_workgroups, 1, 1);
250        }
251
252        let staging_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
253            label: Some("Sum Staging Buffer"),
254            size: std::mem::size_of_val(partial_results.as_slice()) as u64,
255            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
256            mapped_at_creation: false,
257        });
258
259        encoder.copy_buffer_to_buffer(
260            &result_buffer,
261            0,
262            &staging_buffer,
263            0,
264            std::mem::size_of_val(partial_results.as_slice()) as u64,
265        );
266
267        self.queue.submit(Some(encoder.finish()));
268
269        let buffer_slice = staging_buffer.slice(..);
270        let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel();
271        buffer_slice.map_async(wgpu::MapMode::Read, move |result| {
272            sender.send(result).ok();
273        });
274
275        // Poll device to ensure GPU work completes and callbacks are invoked
276        self.device.poll(wgpu::PollType::Wait { submission_index: None, timeout: None }).ok();
277        receiver
278            .receive()
279            .await
280            .ok_or("Channel receive failed")?
281            .map_err(|e| format!("Buffer map failed: {:?}", e))?;
282
283        let data = buffer_slice.get_mapped_range();
284        let result: Vec<f32> = bytemuck::cast_slice(&data).to_vec();
285        drop(data);
286        staging_buffer.unmap();
287
288        // Final reduction on CPU
289        Ok(result.iter().sum())
290    }
291}