use wgpu::{
    Adapter, BindGroup, BindGroupDescriptor, BindGroupEntry, BindGroupLayout,
    BindGroupLayoutDescriptor, BindGroupLayoutEntry, Buffer, BufferAsyncError, BufferDescriptor,
    BufferSlice, BufferUsages, CommandEncoder, ComputePipeline, ComputePipelineDescriptor, Device,
    Features, Instance, Limits, MapMode, MemoryHints, PipelineCompilationOptions,
    PipelineLayoutDescriptor, PowerPreference, Queue, ShaderModuleDescriptor, ShaderSource,
    util::{BufferInitDescriptor, DeviceExt},
};

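/// Requests a high-performance GPU adapter from the default wgpu instance,
/// panicking if the system has no suitable adapter.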
pub async fn get_adapter() -> Adapter {
    let instance = Instance::default();

    instance
        .request_adapter(&wgpu::RequestAdapterOptions {
            compatible_surface: None,
            power_preference: PowerPreference::HighPerformance,
            force_fallback_adapter: false,
        })
        .await
        .expect("No suitable GPU adapters found on the system!")
}

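/// Requests a device and queue from the adapter, raising the compute-related
/// limits to the adapter's reported maximums where possible.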
pub async fn get_device(adapter: &Adapter) -> (Device, Queue) {
    let adapter_limits = adapter.limits();
    let required_limits = Limits {
        max_buffer_size: adapter_limits.max_buffer_size,
        max_storage_buffer_binding_size: adapter_limits.max_storage_buffer_binding_size,
        max_compute_workgroup_storage_size: adapter_limits.max_compute_workgroup_storage_size,
        max_compute_workgroup_size_x: 1024,
        max_compute_invocations_per_workgroup: 1024,
        max_compute_workgroups_per_dimension: adapter_limits.max_compute_workgroups_per_dimension,
        max_storage_buffers_per_shader_stage: adapter_limits.max_storage_buffers_per_shader_stage,
        max_bind_groups: adapter_limits.max_bind_groups,
        max_bindings_per_bind_group: adapter_limits.max_bindings_per_bind_group,
        ..Default::default()
    };

    let (device, queue) = adapter
        .request_device(
            &wgpu::DeviceDescriptor {
                label: None,
                required_limits,
                required_features: Features::empty(),
                memory_hints: MemoryHints::default(),
            },
            None,
        )
        .await
        .expect("Could not create device from adapter");

    (device, queue)
}

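/// Creates an empty storage buffer that can serve as both a copy source and a
/// copy destination.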
pub fn create_storage_buffer(label: Option<&str>, device: &Device, size: u64) -> Buffer {
    device.create_buffer(&BufferDescriptor {
        label,
        size,
        usage: BufferUsages::STORAGE | BufferUsages::COPY_SRC | BufferUsages::COPY_DST,
        mapped_at_creation: false,
    })
}

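/// Creates a storage buffer and initializes it with `data` in one step via
/// `create_buffer_init`.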
pub fn create_and_write_storage_buffer(
    label: Option<&str>,
    device: &Device,
    data: &[u8],
) -> Buffer {
    device.create_buffer_init(&BufferInitDescriptor {
        label,
        contents: data,
        usage: BufferUsages::STORAGE | BufferUsages::COPY_DST | BufferUsages::COPY_SRC,
    })
}

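/// Creates a uniform buffer sized to `data` and schedules a write of `data`
/// into it on the given queue.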
pub fn create_and_write_uniform_buffer(
    label: Option<&str>,
    device: &Device,
    queue: &Queue,
    data: &[u8],
) -> Buffer {
    let buffer = device.create_buffer(&BufferDescriptor {
        label,
        size: data.len() as u64,
        usage: BufferUsages::UNIFORM | BufferUsages::COPY_DST,
        mapped_at_creation: false,
    });

    queue.write_buffer(&buffer, 0, data);

    buffer
}

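/// Copies each storage buffer into a freshly created staging buffer, submits
/// the encoder, and reads the results back as byte vectors. Mapping goes
/// through the future-based `map_buffer_async_browser` helper, so this variant
/// is aimed at wasm/browser targets where the runtime drives buffer mapping.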
pub async fn read_from_gpu(
    device: &Device,
    queue: &Queue,
    mut encoder: CommandEncoder,
    storage_buffers: Vec<Buffer>,
) -> Vec<Vec<u8>> {
    let mut staging_buffers = Vec::new();

    for (i, storage_buffer) in storage_buffers.iter().enumerate() {
        let size = storage_buffer.size();
        let staging_buffer = device.create_buffer(&BufferDescriptor {
            label: Some(&format!("Staging Buffer {i}")),
            size,
            usage: BufferUsages::MAP_READ | BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });
        encoder.copy_buffer_to_buffer(storage_buffer, 0, &staging_buffer, 0, size);
        staging_buffers.push(staging_buffer);
    }

    let command_buffer = encoder.finish();

    queue.submit(vec![command_buffer]);
    device.poll(wgpu::Maintain::Wait);

    let mut data = Vec::new();
    for staging_buffer in staging_buffers {
        let staging_slice = staging_buffer.slice(..);
        // Surface mapping failures instead of silently discarding the Result.
        map_buffer_async_browser(staging_slice, MapMode::Read)
            .await
            .expect("Failed to map staging buffer");
        device.poll(wgpu::Maintain::Wait);
        let result_data = staging_slice.get_mapped_range();
        data.push(result_data.to_vec());
    }

    data
}

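/// Native variant of `read_from_gpu`: maps each staging buffer with a plain
/// `map_async` callback and relies on `device.poll(Maintain::Wait)` to drive
/// the mapping to completion before reading.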
pub async fn read_from_gpu_test(
    device: &Device,
    queue: &Queue,
    mut encoder: CommandEncoder,
    storage_buffers: Vec<Buffer>,
) -> Vec<Vec<u8>> {
    let mut staging_buffers = Vec::new();

    for (i, storage_buffer) in storage_buffers.iter().enumerate() {
        let size = storage_buffer.size();
        let staging_buffer = device.create_buffer(&BufferDescriptor {
            label: Some(&format!("Staging Buffer {i}")),
            size,
            usage: BufferUsages::MAP_READ | BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });
        encoder.copy_buffer_to_buffer(storage_buffer, 0, &staging_buffer, 0, size);
        staging_buffers.push(staging_buffer);
    }

    let command_buffer = encoder.finish();

    queue.submit(vec![command_buffer]);

    let mut data = Vec::new();
    for staging_buffer in staging_buffers {
        let staging_slice = staging_buffer.slice(..);
        staging_slice.map_async(MapMode::Read, |x| x.unwrap());
        device.poll(wgpu::Maintain::Wait);
        let result_data = staging_slice.get_mapped_range();
        data.push(result_data.to_vec());
    }

    data
}

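/// Builds a bind group layout in which bindings are numbered sequentially:
/// read-only storage buffers first, then read-write storage buffers, then
/// uniform buffers. `create_bind_group` assumes the same ordering.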
pub fn create_bind_group_layout(
    label: Option<&str>,
    device: &Device,
    storage_buffers_read_only: Vec<&Buffer>,
    storage_buffers: Vec<&Buffer>,
    uniform_buffers: Vec<&Buffer>,
) -> BindGroupLayout {
    let storage_buffer_read_only_entries = (0..storage_buffers_read_only.len())
        .map(|i| default_storage_read_only_buffer_entry(i as u32))
        .collect::<Vec<_>>();
    let storage_buffer_entries = (0..storage_buffers.len())
        .map(|i| default_storage_buffer_entry((i + storage_buffers_read_only.len()) as u32))
        .collect::<Vec<_>>();
    let uniform_buffer_entries = (0..uniform_buffers.len())
        .map(|i| {
            default_uniform_buffer_entry(
                (i + storage_buffers.len() + storage_buffers_read_only.len()) as u32,
            )
        })
        .collect::<Vec<_>>();

    device.create_bind_group_layout(&BindGroupLayoutDescriptor {
        label,
        entries: &[
            storage_buffer_read_only_entries,
            storage_buffer_entries,
            uniform_buffer_entries,
        ]
        .concat(),
    })
}

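/// Layout entry for a read-write storage buffer visible to compute shaders.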
pub fn default_storage_buffer_entry(idx: u32) -> BindGroupLayoutEntry {
    BindGroupLayoutEntry {
        binding: idx,
        visibility: wgpu::ShaderStages::COMPUTE,
        ty: wgpu::BindingType::Buffer {
            ty: wgpu::BufferBindingType::Storage { read_only: false },
            has_dynamic_offset: false,
            min_binding_size: None,
        },
        count: None,
    }
}

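/// Layout entry for a read-only storage buffer visible to compute shaders.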
pub fn default_storage_read_only_buffer_entry(idx: u32) -> BindGroupLayoutEntry {
    BindGroupLayoutEntry {
        binding: idx,
        visibility: wgpu::ShaderStages::COMPUTE,
        ty: wgpu::BindingType::Buffer {
            ty: wgpu::BufferBindingType::Storage { read_only: true },
            has_dynamic_offset: false,
            min_binding_size: None,
        },
        count: None,
    }
}

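/// Layout entry for a uniform buffer visible to compute shaders.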
pub fn default_uniform_buffer_entry(idx: u32) -> BindGroupLayoutEntry {
    BindGroupLayoutEntry {
        binding: idx,
        visibility: wgpu::ShaderStages::COMPUTE,
        ty: wgpu::BindingType::Buffer {
            ty: wgpu::BufferBindingType::Uniform,
            has_dynamic_offset: false,
            min_binding_size: None,
        },
        count: None,
    }
}

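/// Binds each buffer to the binding index matching its position in `buffers`.
/// The order must match the one used in `create_bind_group_layout`:
/// read-only storage buffers, then read-write storage buffers, then uniforms.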
pub fn create_bind_group(
    label: Option<&str>,
    device: &Device,
    bind_group_layout: &BindGroupLayout,
    buffers: Vec<&Buffer>,
) -> BindGroup {
    device.create_bind_group(&BindGroupDescriptor {
        label,
        layout: bind_group_layout,
        entries: &buffers
            .iter()
            .enumerate()
            .map(|(i, buffer)| BindGroupEntry {
                binding: i as u32,
                resource: buffer.as_entire_binding(),
            })
            .collect::<Vec<_>>(),
    })
}

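/// Compiles the given WGSL source and wraps it in a compute pipeline whose
/// layout contains the single provided bind group layout.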
pub async fn create_compute_pipeline(
    label: Option<&str>,
    device: &Device,
    bind_group_layout: &BindGroupLayout,
    code: &str,
    entry_point: &str,
) -> ComputePipeline {
    let pipeline_layout = device.create_pipeline_layout(&PipelineLayoutDescriptor {
        label,
        bind_group_layouts: &[bind_group_layout],
        push_constant_ranges: &[],
    });

    let module = device.create_shader_module(ShaderModuleDescriptor {
        label,
        source: ShaderSource::Wgsl(code.into()),
    });

    device.create_compute_pipeline(&ComputePipelineDescriptor {
        label,
        layout: Some(&pipeline_layout),
        module: &module,
        entry_point: Some(entry_point),
        compilation_options: PipelineCompilationOptions::default(),
        cache: None,
    })
}

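/// Records a compute pass that dispatches the pipeline with the given
/// workgroup counts. The pass ends when `cpass` is dropped at the end of the
/// function; the caller still has to finish and submit the encoder.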
pub async fn execute_pipeline(
    encoder: &mut CommandEncoder,
    pipeline: ComputePipeline,
    bind_group: BindGroup,
    num_x_workgroups: u32,
    num_y_workgroups: u32,
    num_z_workgroups: u32,
) {
    let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: None,
        timestamp_writes: None,
    });
    cpass.set_pipeline(&pipeline);
    cpass.set_bind_group(0, &bind_group, &[]);
    cpass.dispatch_workgroups(num_x_workgroups, num_y_workgroups, num_z_workgroups);
}

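/// Wraps `BufferSlice::map_async` in a future by forwarding the callback
/// result through a oneshot channel (this assumes the `oneshot` crate, whose
/// receiver implements `Future`). If the sender is dropped before a result
/// arrives, the mapping is reported as failed.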
pub fn map_buffer_async_browser(
    slice: BufferSlice<'_>,
    mode: MapMode,
) -> impl std::future::Future<Output = Result<(), BufferAsyncError>> {
    let (sender, receiver) = oneshot::channel();
    slice.map_async(mode, move |res| {
        let _ = sender.send(res);
    });
    async move {
        match receiver.await {
            Ok(result) => result,
            Err(_) => Err(BufferAsyncError {}),
        }
    }
}