// trueno/backends/gpu/device/reductions/reduce_1d.rs
use super::super::super::shaders;
use super::super::GpuDevice;

8impl GpuDevice {
9 pub(in crate::backends::gpu::device) async fn reduce_max(
11 &self,
12 input: &[f32],
13 ) -> Result<f32, String> {
14 let len = input.len();
15 let workgroup_size = 256;
16 let num_workgroups = (len as u32).div_ceil(workgroup_size);
17
18 let shader = self.device.create_shader_module(wgpu::ShaderModuleDescriptor {
20 label: Some("Max Reduction Shader"),
21 source: wgpu::ShaderSource::Wgsl(shaders::MAX_REDUCTION_SHADER.into()),
22 });
23
24 let input_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
26 label: Some("Max Reduction Input"),
27 size: std::mem::size_of_val(input) as u64,
28 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
29 mapped_at_creation: false,
30 });
31
32 let partial_results = vec![f32::NEG_INFINITY; num_workgroups as usize];
34 let result_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
35 label: Some("Max Partial Results"),
36 size: std::mem::size_of_val(partial_results.as_slice()) as u64,
37 usage: wgpu::BufferUsages::STORAGE
38 | wgpu::BufferUsages::COPY_SRC
39 | wgpu::BufferUsages::COPY_DST,
40 mapped_at_creation: false,
41 });
42
43 self.queue.write_buffer(&input_buffer, 0, bytemuck::cast_slice(input));
44
45 let bind_group_layout =
47 self.device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
48 label: Some("Max Reduction Bind Group Layout"),
49 entries: &[
50 wgpu::BindGroupLayoutEntry {
51 binding: 0,
52 visibility: wgpu::ShaderStages::COMPUTE,
53 ty: wgpu::BindingType::Buffer {
54 ty: wgpu::BufferBindingType::Storage { read_only: true },
55 has_dynamic_offset: false,
56 min_binding_size: None,
57 },
58 count: None,
59 },
60 wgpu::BindGroupLayoutEntry {
61 binding: 1,
62 visibility: wgpu::ShaderStages::COMPUTE,
63 ty: wgpu::BindingType::Buffer {
64 ty: wgpu::BufferBindingType::Storage { read_only: false },
65 has_dynamic_offset: false,
66 min_binding_size: None,
67 },
68 count: None,
69 },
70 ],
71 });
72
73 let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
74 label: Some("Max Reduction Bind Group"),
75 layout: &bind_group_layout,
76 entries: &[
77 wgpu::BindGroupEntry { binding: 0, resource: input_buffer.as_entire_binding() },
78 wgpu::BindGroupEntry { binding: 1, resource: result_buffer.as_entire_binding() },
79 ],
80 });
81
82 let pipeline_layout = self.device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
83 label: Some("Max Reduction Pipeline Layout"),
84 bind_group_layouts: &[&bind_group_layout],
85 push_constant_ranges: &[],
86 });
87
88 let pipeline = self.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
89 label: Some("Max Reduction Pipeline"),
90 layout: Some(&pipeline_layout),
91 module: &shader,
92 entry_point: Some("main"),
93 compilation_options: Default::default(),
94 cache: None,
95 });
96
97 let mut encoder = self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
98 label: Some("Max Reduction Encoder"),
99 });
100
101 {
102 let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
103 label: Some("Max Reduction Pass"),
104 timestamp_writes: None,
105 });
106
107 compute_pass.set_pipeline(&pipeline);
108 compute_pass.set_bind_group(0, &bind_group, &[]);
109 compute_pass.dispatch_workgroups(num_workgroups, 1, 1);
110 }
111
112 let staging_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
114 label: Some("Max Staging Buffer"),
115 size: std::mem::size_of_val(partial_results.as_slice()) as u64,
116 usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
117 mapped_at_creation: false,
118 });
119
120 encoder.copy_buffer_to_buffer(
121 &result_buffer,
122 0,
123 &staging_buffer,
124 0,
125 std::mem::size_of_val(partial_results.as_slice()) as u64,
126 );
127
128 self.queue.submit(Some(encoder.finish()));
129
130 let buffer_slice = staging_buffer.slice(..);
131 let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel();
132 buffer_slice.map_async(wgpu::MapMode::Read, move |result| {
133 sender.send(result).ok();
134 });
135
136 self.device.poll(wgpu::PollType::Wait { submission_index: None, timeout: None }).ok();
138 receiver
139 .receive()
140 .await
141 .ok_or("Channel receive failed")?
142 .map_err(|e| format!("Buffer map failed: {:?}", e))?;
143
144 let data = buffer_slice.get_mapped_range();
145 let result: Vec<f32> = bytemuck::cast_slice(&data).to_vec();
146 drop(data);
147 staging_buffer.unmap();
148
149 Ok(result.iter().copied().fold(f32::NEG_INFINITY, f32::max))
151 }
152
153 pub(in crate::backends::gpu::device) async fn reduce_sum(
155 &self,
156 input: &[f32],
157 ) -> Result<f32, String> {
158 let len = input.len();
159 let workgroup_size = 256;
160 let num_workgroups = (len as u32).div_ceil(workgroup_size);
161
162 let shader = self.device.create_shader_module(wgpu::ShaderModuleDescriptor {
163 label: Some("Sum Reduction Shader"),
164 source: wgpu::ShaderSource::Wgsl(shaders::SUM_REDUCTION_SHADER.into()),
165 });
166
167 let input_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
168 label: Some("Sum Reduction Input"),
169 size: std::mem::size_of_val(input) as u64,
170 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
171 mapped_at_creation: false,
172 });
173
174 let partial_results = vec![0.0f32; num_workgroups as usize];
175 let result_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
176 label: Some("Sum Partial Results"),
177 size: std::mem::size_of_val(partial_results.as_slice()) as u64,
178 usage: wgpu::BufferUsages::STORAGE
179 | wgpu::BufferUsages::COPY_SRC
180 | wgpu::BufferUsages::COPY_DST,
181 mapped_at_creation: false,
182 });
183
184 self.queue.write_buffer(&input_buffer, 0, bytemuck::cast_slice(input));
185
186 let bind_group_layout =
187 self.device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
188 label: Some("Sum Reduction Bind Group Layout"),
189 entries: &[
190 wgpu::BindGroupLayoutEntry {
191 binding: 0,
192 visibility: wgpu::ShaderStages::COMPUTE,
193 ty: wgpu::BindingType::Buffer {
194 ty: wgpu::BufferBindingType::Storage { read_only: true },
195 has_dynamic_offset: false,
196 min_binding_size: None,
197 },
198 count: None,
199 },
200 wgpu::BindGroupLayoutEntry {
201 binding: 1,
202 visibility: wgpu::ShaderStages::COMPUTE,
203 ty: wgpu::BindingType::Buffer {
204 ty: wgpu::BufferBindingType::Storage { read_only: false },
205 has_dynamic_offset: false,
206 min_binding_size: None,
207 },
208 count: None,
209 },
210 ],
211 });
212
213 let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
214 label: Some("Sum Reduction Bind Group"),
215 layout: &bind_group_layout,
216 entries: &[
217 wgpu::BindGroupEntry { binding: 0, resource: input_buffer.as_entire_binding() },
218 wgpu::BindGroupEntry { binding: 1, resource: result_buffer.as_entire_binding() },
219 ],
220 });
221
222 let pipeline_layout = self.device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
223 label: Some("Sum Reduction Pipeline Layout"),
224 bind_group_layouts: &[&bind_group_layout],
225 push_constant_ranges: &[],
226 });
227
228 let pipeline = self.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
229 label: Some("Sum Reduction Pipeline"),
230 layout: Some(&pipeline_layout),
231 module: &shader,
232 entry_point: Some("main"),
233 compilation_options: Default::default(),
234 cache: None,
235 });
236
237 let mut encoder = self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
238 label: Some("Sum Reduction Encoder"),
239 });
240
241 {
242 let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
243 label: Some("Sum Reduction Pass"),
244 timestamp_writes: None,
245 });
246
247 compute_pass.set_pipeline(&pipeline);
248 compute_pass.set_bind_group(0, &bind_group, &[]);
249 compute_pass.dispatch_workgroups(num_workgroups, 1, 1);
250 }
251
252 let staging_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
253 label: Some("Sum Staging Buffer"),
254 size: std::mem::size_of_val(partial_results.as_slice()) as u64,
255 usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
256 mapped_at_creation: false,
257 });
258
259 encoder.copy_buffer_to_buffer(
260 &result_buffer,
261 0,
262 &staging_buffer,
263 0,
264 std::mem::size_of_val(partial_results.as_slice()) as u64,
265 );
266
267 self.queue.submit(Some(encoder.finish()));
268
269 let buffer_slice = staging_buffer.slice(..);
270 let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel();
271 buffer_slice.map_async(wgpu::MapMode::Read, move |result| {
272 sender.send(result).ok();
273 });
274
275 self.device.poll(wgpu::PollType::Wait { submission_index: None, timeout: None }).ok();
277 receiver
278 .receive()
279 .await
280 .ok_or("Channel receive failed")?
281 .map_err(|e| format!("Buffer map failed: {:?}", e))?;
282
283 let data = buffer_slice.get_mapped_range();
284 let result: Vec<f32> = bytemuck::cast_slice(&data).to_vec();
285 drop(data);
286 staging_buffer.unmap();
287
288 Ok(result.iter().sum())
290 }
291}