// msm_webgpu/cuzk/gpu.rs — low-level wgpu helpers (adapter/device setup,
// buffer creation, readback, bind groups, compute pipelines).
1use wgpu::{
2    Adapter, BindGroup, BindGroupDescriptor, BindGroupEntry, BindGroupLayout,
3    BindGroupLayoutDescriptor, BindGroupLayoutEntry, Buffer, BufferAsyncError, BufferDescriptor,
4    BufferSlice, BufferUsages, CommandEncoder, ComputePipeline, ComputePipelineDescriptor, Device,
5    Features, Instance, Limits, MapMode, MemoryHints, PipelineCompilationOptions,
6    PipelineLayoutDescriptor, PowerPreference, Queue, ShaderModuleDescriptor, ShaderSource,
7    util::{BufferInitDescriptor, DeviceExt},
8};
9
10/// Get an adapter
11pub async fn get_adapter() -> Adapter {
12    let instance = Instance::default();
13
14    // Request an adapter (the GPU) from the browser
15    instance
16        .request_adapter(&wgpu::RequestAdapterOptions {
17            compatible_surface: None,
18            power_preference: PowerPreference::HighPerformance,
19            force_fallback_adapter: false,
20        })
21        .await
22        .expect("No suitable GPU adapters found on the system!")
23}
24
25/// Get a device
26pub async fn get_device(adapter: &Adapter) -> (Device, Queue) {
27    let required_limits = Limits {
28        max_buffer_size: adapter.limits().max_buffer_size,
29        max_storage_buffer_binding_size: adapter.limits().max_storage_buffer_binding_size,
30        max_compute_workgroup_storage_size: adapter.limits().max_compute_workgroup_storage_size,
31        max_compute_workgroup_size_x: 1024,
32        max_compute_invocations_per_workgroup: 1024,
33        max_compute_workgroups_per_dimension: adapter.limits().max_compute_workgroups_per_dimension,
34        max_storage_buffers_per_shader_stage: adapter.limits().max_storage_buffers_per_shader_stage,
35        max_bind_groups: adapter.limits().max_bind_groups,
36        max_bindings_per_bind_group: adapter.limits().max_bindings_per_bind_group,
37        ..Default::default()
38    };
39
40    let (device, queue) = adapter
41        .request_device(
42            &wgpu::DeviceDescriptor {
43                label: None,
44                required_limits,
45                required_features: Features::empty(),
46                memory_hints: MemoryHints::default(), // Favor performance over memory usage
47            },
48            None,
49        )
50        .await
51        .expect("Could not create adapter for device");
52
53    (device, queue)
54}
55
56/// Create a storage buffer
57pub fn create_storage_buffer(label: Option<&str>, device: &Device, size: u64) -> Buffer {
58    device.create_buffer(&BufferDescriptor {
59        label,
60        size,
61        usage: BufferUsages::STORAGE | BufferUsages::COPY_SRC | BufferUsages::COPY_DST,
62        mapped_at_creation: false,
63    })
64}
65
66/// Create a storage buffer and write data to it
67pub fn create_and_write_storage_buffer(
68    label: Option<&str>,
69    device: &Device,
70    data: &[u8],
71) -> Buffer {
72    device.create_buffer_init(&BufferInitDescriptor {
73        label,
74        contents: data,
75        usage: BufferUsages::STORAGE | BufferUsages::COPY_DST | BufferUsages::COPY_SRC,
76    })
77}
78
79/// Create a uniform buffer and write data to it
80pub fn create_and_write_uniform_buffer(
81    label: Option<&str>,
82    device: &Device,
83    queue: &Queue,
84    data: &[u8],
85) -> Buffer {
86    let buffer = device.create_buffer(&BufferDescriptor {
87        label,
88        size: data.len() as u64,
89        usage: BufferUsages::UNIFORM | BufferUsages::COPY_DST,
90        mapped_at_creation: false,
91    });
92
93    queue.write_buffer(&buffer, 0, data);
94
95    buffer
96}
97
98/// Read data from the GPU
99pub async fn read_from_gpu(
100    device: &Device,
101    queue: &Queue,
102    mut encoder: CommandEncoder,
103    storage_buffers: Vec<Buffer>,
104) -> Vec<Vec<u8>> {
105    let mut staging_buffers = Vec::new();
106
107    for (i, storage_buffer) in storage_buffers.iter().enumerate() {
108        let size = storage_buffer.size();
109        let staging_buffer = device.create_buffer(&BufferDescriptor {
110            label: Some(&format!("Staging Buffer {i}")),
111            size,
112            usage: BufferUsages::MAP_READ | BufferUsages::COPY_DST,
113            mapped_at_creation: false,
114        });
115        encoder.copy_buffer_to_buffer(storage_buffer, 0, &staging_buffer, 0, size);
116        staging_buffers.push(staging_buffer);
117    }
118
119    let command_buffer = encoder.finish();
120
121    queue.submit(vec![command_buffer]);
122    device.poll(wgpu::Maintain::Wait);
123
124    let mut data = Vec::new();
125    for staging_buffer in staging_buffers {
126        let staging_slice = staging_buffer.slice(..);
127        let _buffer_future = map_buffer_async_browser(staging_slice, MapMode::Read).await;
128        device.poll(wgpu::Maintain::Wait);
129        let result_data = staging_slice.get_mapped_range();
130        data.push(result_data.to_vec());
131    }
132
133    data
134}
135
136/// Read data from the GPU for testing
137pub async fn read_from_gpu_test(
138    device: &Device,
139    queue: &Queue,
140    mut encoder: CommandEncoder,
141    storage_buffers: Vec<Buffer>,
142) -> Vec<Vec<u8>> {
143    let mut staging_buffers = Vec::new();
144
145    for (i, storage_buffer) in storage_buffers.iter().enumerate() {
146        let size = storage_buffer.size();
147        let staging_buffer = device.create_buffer(&BufferDescriptor {
148            label: Some(&format!("Staging Buffer {i}")),
149            size,
150            usage: BufferUsages::MAP_READ | BufferUsages::COPY_DST,
151            mapped_at_creation: false,
152        });
153        encoder.copy_buffer_to_buffer(storage_buffer, 0, &staging_buffer, 0, size);
154        staging_buffers.push(staging_buffer);
155    }
156
157    let command_buffer = encoder.finish();
158
159    queue.submit(vec![command_buffer]);
160
161    let mut data = Vec::new();
162    for staging_buffer in staging_buffers {
163        let staging_slice = staging_buffer.slice(..);
164        staging_slice.map_async(MapMode::Read, |x| x.unwrap());
165        device.poll(wgpu::Maintain::Wait);
166        let result_data = staging_slice.get_mapped_range();
167        data.push(result_data.to_vec());
168    }
169
170    data
171}
172
173/// Create a bind group layout
174pub fn create_bind_group_layout(
175    label: Option<&str>,
176    device: &Device,
177    storage_buffers_read_only: Vec<&Buffer>,
178    storage_buffers: Vec<&Buffer>,
179    uniform_buffers: Vec<&Buffer>,
180) -> BindGroupLayout {
181    let storage_buffer_read_only_entries = (0..storage_buffers_read_only.len())
182        .map(|i| default_storage_read_only_buffer_entry(i as u32))
183        .collect::<Vec<_>>();
184    let storage_buffer_entries = (0..storage_buffers.len())
185        .map(|i| default_storage_buffer_entry((i + storage_buffers_read_only.len()) as u32))
186        .collect::<Vec<_>>();
187
188    let uniform_buffer_entries = (0..uniform_buffers.len())
189        .map(|i| {
190            default_uniform_buffer_entry(
191                (i + storage_buffers.len() + storage_buffers_read_only.len()) as u32,
192            )
193        })
194        .collect::<Vec<_>>();
195    device.create_bind_group_layout(&BindGroupLayoutDescriptor {
196        label,
197        entries: &[storage_buffer_read_only_entries,
198            storage_buffer_entries,
199            uniform_buffer_entries]
200        .concat(),
201    })
202}
203
204/// Default storage buffer entry
205pub fn default_storage_buffer_entry(idx: u32) -> BindGroupLayoutEntry {
206    BindGroupLayoutEntry {
207        binding: idx,
208        visibility: wgpu::ShaderStages::COMPUTE,
209        ty: wgpu::BindingType::Buffer {
210            ty: wgpu::BufferBindingType::Storage { read_only: false },
211            has_dynamic_offset: false,
212            min_binding_size: None,
213        },
214        count: None,
215    }
216}
217
218/// Default storage read only buffer entry
219pub fn default_storage_read_only_buffer_entry(idx: u32) -> BindGroupLayoutEntry {
220    BindGroupLayoutEntry {
221        binding: idx,
222        visibility: wgpu::ShaderStages::COMPUTE,
223        ty: wgpu::BindingType::Buffer {
224            ty: wgpu::BufferBindingType::Storage { read_only: true },
225            has_dynamic_offset: false,
226            min_binding_size: None,
227        },
228        count: None,
229    }
230}
231
232/// Default uniform buffer entry
233pub fn default_uniform_buffer_entry(idx: u32) -> BindGroupLayoutEntry {
234    BindGroupLayoutEntry {
235        binding: idx,
236        visibility: wgpu::ShaderStages::COMPUTE,
237        ty: wgpu::BindingType::Buffer {
238            ty: wgpu::BufferBindingType::Uniform,
239            has_dynamic_offset: false,
240            min_binding_size: None,
241        },
242        count: None,
243    }
244}
245
246/// Create a bind group
247pub fn create_bind_group(
248    label: Option<&str>,
249    device: &Device,
250    bind_group_layout: &BindGroupLayout,
251    buffers: Vec<&Buffer>,
252) -> BindGroup {
253    device.create_bind_group(&BindGroupDescriptor {
254        label,
255        layout: bind_group_layout,
256        entries: &buffers
257            .iter()
258            .enumerate()
259            .map(|(i, buffer)| BindGroupEntry {
260                binding: i as u32,
261                resource: buffer.as_entire_binding(),
262            })
263            .collect::<Vec<_>>(),
264    })
265}
266
267/// Create a compute pipeline
268pub async fn create_compute_pipeline(
269    label: Option<&str>,
270    device: &Device,
271    bind_group_layout: &BindGroupLayout,
272    code: &str,
273    entry_point: &str,
274) -> ComputePipeline {
275    let pipeline_layout = device.create_pipeline_layout(&PipelineLayoutDescriptor {
276        label,
277        bind_group_layouts: &[bind_group_layout],
278        push_constant_ranges: &[],
279    });
280
281    let module = device.create_shader_module(ShaderModuleDescriptor {
282        label,
283        source: ShaderSource::Wgsl(code.into()),
284    });
285
286    device.create_compute_pipeline(&ComputePipelineDescriptor {
287        label,
288        layout: Some(&pipeline_layout),
289        module: &module,
290        entry_point: Some(entry_point),
291        compilation_options: PipelineCompilationOptions::default(),
292        cache: None,
293    })
294}
295
296/// Execute a compute pipeline
297pub async fn execute_pipeline(
298    encoder: &mut CommandEncoder,
299    pipeline: ComputePipeline,
300    bind_group: BindGroup,
301    num_x_workgroups: u32,
302    num_y_workgroups: u32,
303    num_z_workgroups: u32,
304) {
305    let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
306        label: None,
307        timestamp_writes: None,
308    });
309    cpass.set_pipeline(&pipeline);
310    cpass.set_bind_group(0, &bind_group, &[]);
311    cpass.dispatch_workgroups(num_x_workgroups, num_y_workgroups, num_z_workgroups);
312}
313
314/// Map a buffer asynchronously
315pub fn map_buffer_async_browser(
316    slice: BufferSlice<'_>,
317    mode: MapMode,
318) -> impl std::future::Future<Output = Result<(), BufferAsyncError>> {
319    let (sender, receiver) = oneshot::channel();
320    slice.map_async(mode, move |res| {
321        let _ = sender.send(res);
322    });
323    async move {
324        match receiver.await {
325            Ok(result) => result,
326            Err(_) => Err(BufferAsyncError {}),
327        }
328    }
329}