// trueno/backends/gpu/device/reductions/tiled_2d.rs
use super::super::GpuDevice;
7impl GpuDevice {
8 #[allow(clippy::too_many_arguments)]
10 pub(super) async fn tiled_reduce_2d_async<F>(
11 &self,
12 data: &[f32],
13 width: usize,
14 height: usize,
15 shader_source: &str,
16 op_name: &str,
17 identity: f32,
18 combine: F,
19 ) -> Result<f32, String>
20 where
21 F: Fn(&[f32]) -> f32,
22 {
23 if data.is_empty() || width == 0 || height == 0 {
24 return Ok(identity);
25 }
26
27 let workgroup_size_x: u32 = 16;
29 let workgroup_size_y: u32 = 16;
30 let num_workgroups_x = (width as u32).div_ceil(workgroup_size_x);
31 let num_workgroups_y = (height as u32).div_ceil(workgroup_size_y);
32 let total_workgroups = (num_workgroups_x * num_workgroups_y) as usize;
33
34 let shader = self.device.create_shader_module(wgpu::ShaderModuleDescriptor {
36 label: Some(&format!("{} Shader", op_name)),
37 source: wgpu::ShaderSource::Wgsl(shader_source.into()),
38 });
39
40 let input_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
42 label: Some(&format!("{} Input", op_name)),
43 size: std::mem::size_of_val(data) as u64,
44 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
45 mapped_at_creation: false,
46 });
47
48 let partial_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
50 label: Some(&format!("{} Partial Results", op_name)),
51 size: (total_workgroups * std::mem::size_of::<f32>()) as u64,
52 usage: wgpu::BufferUsages::STORAGE
53 | wgpu::BufferUsages::COPY_SRC
54 | wgpu::BufferUsages::COPY_DST,
55 mapped_at_creation: false,
56 });
57
58 #[repr(C)]
60 #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
61 struct Dimensions {
62 width: u32,
63 height: u32,
64 }
65
66 let dims = Dimensions { width: width as u32, height: height as u32 };
67
68 let dims_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
69 label: Some(&format!("{} Dimensions", op_name)),
70 size: std::mem::size_of::<Dimensions>() as u64,
71 usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
72 mapped_at_creation: false,
73 });
74
75 self.queue.write_buffer(&input_buffer, 0, bytemuck::cast_slice(data));
77 self.queue.write_buffer(&dims_buffer, 0, bytemuck::bytes_of(&dims));
78
79 let bind_group_layout =
81 self.device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
82 label: Some(&format!("{} Bind Group Layout", op_name)),
83 entries: &[
84 wgpu::BindGroupLayoutEntry {
85 binding: 0,
86 visibility: wgpu::ShaderStages::COMPUTE,
87 ty: wgpu::BindingType::Buffer {
88 ty: wgpu::BufferBindingType::Storage { read_only: true },
89 has_dynamic_offset: false,
90 min_binding_size: None,
91 },
92 count: None,
93 },
94 wgpu::BindGroupLayoutEntry {
95 binding: 1,
96 visibility: wgpu::ShaderStages::COMPUTE,
97 ty: wgpu::BindingType::Buffer {
98 ty: wgpu::BufferBindingType::Storage { read_only: false },
99 has_dynamic_offset: false,
100 min_binding_size: None,
101 },
102 count: None,
103 },
104 wgpu::BindGroupLayoutEntry {
105 binding: 2,
106 visibility: wgpu::ShaderStages::COMPUTE,
107 ty: wgpu::BindingType::Buffer {
108 ty: wgpu::BufferBindingType::Uniform,
109 has_dynamic_offset: false,
110 min_binding_size: None,
111 },
112 count: None,
113 },
114 ],
115 });
116
117 let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
119 label: Some(&format!("{} Bind Group", op_name)),
120 layout: &bind_group_layout,
121 entries: &[
122 wgpu::BindGroupEntry { binding: 0, resource: input_buffer.as_entire_binding() },
123 wgpu::BindGroupEntry { binding: 1, resource: partial_buffer.as_entire_binding() },
124 wgpu::BindGroupEntry { binding: 2, resource: dims_buffer.as_entire_binding() },
125 ],
126 });
127
128 let pipeline_layout = self.device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
130 label: Some(&format!("{} Pipeline Layout", op_name)),
131 bind_group_layouts: &[&bind_group_layout],
132 push_constant_ranges: &[],
133 });
134
135 let pipeline = self.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
136 label: Some(&format!("{} Pipeline", op_name)),
137 layout: Some(&pipeline_layout),
138 module: &shader,
139 entry_point: Some("main"),
140 compilation_options: Default::default(),
141 cache: None,
142 });
143
144 let staging_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
146 label: Some(&format!("{} Staging", op_name)),
147 size: (total_workgroups * std::mem::size_of::<f32>()) as u64,
148 usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
149 mapped_at_creation: false,
150 });
151
152 let mut encoder = self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
154 label: Some(&format!("{} Encoder", op_name)),
155 });
156
157 {
158 let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
159 label: Some(&format!("{} Pass", op_name)),
160 timestamp_writes: None,
161 });
162 compute_pass.set_pipeline(&pipeline);
163 compute_pass.set_bind_group(0, &bind_group, &[]);
164 compute_pass.dispatch_workgroups(num_workgroups_x, num_workgroups_y, 1);
165 }
166
167 encoder.copy_buffer_to_buffer(
169 &partial_buffer,
170 0,
171 &staging_buffer,
172 0,
173 (total_workgroups * std::mem::size_of::<f32>()) as u64,
174 );
175
176 self.queue.submit(Some(encoder.finish()));
178
179 let buffer_slice = staging_buffer.slice(..);
181 let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel();
182 buffer_slice.map_async(wgpu::MapMode::Read, move |result| {
183 sender.send(result).ok();
184 });
185
186 self.device.poll(wgpu::PollType::Wait { submission_index: None, timeout: None }).ok();
188
189 receiver
190 .receive()
191 .await
192 .ok_or("Failed to receive mapping result")?
193 .map_err(|e| format!("Buffer mapping failed: {:?}", e))?;
194
195 let final_result = {
196 let data = buffer_slice.get_mapped_range();
197 let partials: &[f32] = bytemuck::cast_slice(&data);
198 combine(partials)
199 };
200
201 staging_buffer.unmap();
202
203 Ok(final_result)
204 }
205}