Skip to main content

runmat_plot/gpu/
util.rs

1use crate::gpu::ScalarType;
2use bytemuck::cast_slice;
3use futures::channel::oneshot;
4use std::sync::Arc;
5
6pub(crate) async fn map_read_async(
7    device: &wgpu::Device,
8    slice: &wgpu::BufferSlice<'_>,
9) -> Result<(), String> {
10    let (tx, rx) = oneshot::channel();
11    slice.map_async(wgpu::MapMode::Read, move |result| {
12        let _ = tx.send(result);
13    });
14
15    #[cfg(not(target_arch = "wasm32"))]
16    device.poll(wgpu::Maintain::Wait);
17
18    #[cfg(target_arch = "wasm32")]
19    device.poll(wgpu::Maintain::Poll);
20
21    rx.await
22        .map_err(|_| "map failed".to_string())?
23        .map_err(|_| "map error".to_string())?;
24    Ok(())
25}
26
27pub async fn readback_u32(
28    device: &Arc<wgpu::Device>,
29    buffer: &wgpu::Buffer,
30) -> Result<u32, String> {
31    let slice = buffer.slice(..);
32    map_read_async(device, &slice).await?;
33    let data = slice.get_mapped_range();
34    if data.len() < std::mem::size_of::<u32>() {
35        drop(data);
36        buffer.unmap();
37        return Err("readback buffer too small".to_string());
38    }
39    let mut bytes = [0u8; 4];
40    bytes.copy_from_slice(&data[..4]);
41    drop(data);
42    buffer.unmap();
43    Ok(u32::from_le_bytes(bytes))
44}
45
46pub async fn readback_f32(
47    device: &Arc<wgpu::Device>,
48    buffer: &wgpu::Buffer,
49) -> Result<f32, String> {
50    let bits = readback_u32(device, buffer).await?;
51    Ok(f32::from_bits(bits))
52}
53
54pub async fn readback_f32_buffer(
55    device: &Arc<wgpu::Device>,
56    buffer: &wgpu::Buffer,
57    element_count: usize,
58) -> Result<Vec<f32>, String> {
59    if element_count == 0 {
60        return Ok(Vec::new());
61    }
62    let byte_len = element_count * std::mem::size_of::<f32>();
63    let slice = buffer.slice(0..byte_len as u64);
64    map_read_async(device, &slice).await?;
65    let data = slice.get_mapped_range();
66    if data.len() < byte_len {
67        drop(data);
68        buffer.unmap();
69        return Err("GPU readback buffer too small".to_string());
70    }
71    let floats: &[f32] = cast_slice(&data[..byte_len]);
72    let out = floats.to_vec();
73    drop(data);
74    buffer.unmap();
75    Ok(out)
76}
77
78pub async fn copy_readback_bytes(
79    device: &Arc<wgpu::Device>,
80    queue: &Arc<wgpu::Queue>,
81    buffer: &wgpu::Buffer,
82    byte_len: usize,
83) -> Result<Vec<u8>, String> {
84    if byte_len == 0 {
85        return Ok(Vec::new());
86    }
87
88    let staging = device.create_buffer(&wgpu::BufferDescriptor {
89        label: Some("runmat-plot-readback-staging"),
90        size: byte_len as u64,
91        usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
92        mapped_at_creation: false,
93    });
94    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
95        label: Some("runmat-plot-readback-encoder"),
96    });
97    encoder.copy_buffer_to_buffer(buffer, 0, &staging, 0, byte_len as u64);
98    queue.submit(Some(encoder.finish()));
99
100    let slice = staging.slice(0..byte_len as u64);
101    map_read_async(device, &slice).await?;
102    let data = slice.get_mapped_range();
103    if data.len() < byte_len {
104        drop(data);
105        staging.unmap();
106        return Err("GPU readback staging buffer too small".to_string());
107    }
108    let out = data[..byte_len].to_vec();
109    drop(data);
110    staging.unmap();
111    Ok(out)
112}
113
114pub async fn readback_scalar_buffer_f64(
115    device: &Arc<wgpu::Device>,
116    queue: &Arc<wgpu::Queue>,
117    buffer: &wgpu::Buffer,
118    element_count: usize,
119    scalar: ScalarType,
120) -> Result<Vec<f64>, String> {
121    if element_count == 0 {
122        return Ok(Vec::new());
123    }
124
125    match scalar {
126        ScalarType::F32 => {
127            let byte_len = element_count * std::mem::size_of::<f32>();
128            let bytes = copy_readback_bytes(device, queue, buffer, byte_len).await?;
129            let values: &[f32] = cast_slice(&bytes);
130            Ok(values.iter().map(|value| f64::from(*value)).collect())
131        }
132        ScalarType::F64 => {
133            let byte_len = element_count * std::mem::size_of::<f64>();
134            let bytes = copy_readback_bytes(device, queue, buffer, byte_len).await?;
135            let values: &[f64] = cast_slice(&bytes);
136            Ok(values.to_vec())
137        }
138    }
139}