Skip to main content

runmat_plot/gpu/
util.rs

1use bytemuck::cast_slice;
2use futures::channel::oneshot;
3use std::sync::Arc;
4
5pub(crate) async fn map_read_async(
6    device: &wgpu::Device,
7    slice: &wgpu::BufferSlice<'_>,
8) -> Result<(), String> {
9    let (tx, rx) = oneshot::channel();
10    slice.map_async(wgpu::MapMode::Read, move |result| {
11        let _ = tx.send(result);
12    });
13
14    #[cfg(not(target_arch = "wasm32"))]
15    device.poll(wgpu::Maintain::Wait);
16
17    #[cfg(target_arch = "wasm32")]
18    device.poll(wgpu::Maintain::Poll);
19
20    rx.await
21        .map_err(|_| "map failed".to_string())?
22        .map_err(|_| "map error".to_string())?;
23    Ok(())
24}
25
26pub async fn readback_u32(
27    device: &Arc<wgpu::Device>,
28    buffer: &wgpu::Buffer,
29) -> Result<u32, String> {
30    let slice = buffer.slice(..);
31    map_read_async(device, &slice).await?;
32    let data = slice.get_mapped_range();
33    if data.len() < std::mem::size_of::<u32>() {
34        drop(data);
35        buffer.unmap();
36        return Err("readback buffer too small".to_string());
37    }
38    let mut bytes = [0u8; 4];
39    bytes.copy_from_slice(&data[..4]);
40    drop(data);
41    buffer.unmap();
42    Ok(u32::from_le_bytes(bytes))
43}
44
45pub async fn readback_f32(
46    device: &Arc<wgpu::Device>,
47    buffer: &wgpu::Buffer,
48) -> Result<f32, String> {
49    let bits = readback_u32(device, buffer).await?;
50    Ok(f32::from_bits(bits))
51}
52
53pub async fn readback_f32_buffer(
54    device: &Arc<wgpu::Device>,
55    buffer: &wgpu::Buffer,
56    element_count: usize,
57) -> Result<Vec<f32>, String> {
58    if element_count == 0 {
59        return Ok(Vec::new());
60    }
61    let byte_len = element_count * std::mem::size_of::<f32>();
62    let slice = buffer.slice(0..byte_len as u64);
63    map_read_async(device, &slice).await?;
64    let data = slice.get_mapped_range();
65    if data.len() < byte_len {
66        drop(data);
67        buffer.unmap();
68        return Err("GPU readback buffer too small".to_string());
69    }
70    let floats: &[f32] = cast_slice(&data[..byte_len]);
71    let out = floats.to_vec();
72    drop(data);
73    buffer.unmap();
74    Ok(out)
75}