Skip to main content

runmat_plot/gpu/
axis.rs

1use crate::gpu::ScalarType;
2use std::sync::Arc;
3use wgpu::util::DeviceExt;
4
5/// Axis data source used by GPU plot packers.
6///
7/// For GPU kernels that support both f32 and f64 data paths, we treat axis precision as matching
8/// the input scalar type (f32 axes for f32 shaders, f64 axes for f64 shaders). Shaders can cast
9/// to f32 for final vertex positions as needed.
10#[derive(Clone)]
11pub enum AxisData<'a> {
12    F32(&'a [f32]),
13    F64(&'a [f64]),
14    Buffer(Arc<wgpu::Buffer>),
15}
16
17#[derive(Clone, Debug)]
18pub enum OwnedAxisData {
19    F32(Vec<f32>),
20    F64(Vec<f64>),
21    Buffer(Arc<wgpu::Buffer>),
22}
23
24impl OwnedAxisData {
25    pub fn from_axis(axis: &AxisData<'_>) -> Self {
26        match axis {
27            AxisData::F32(values) => Self::F32(values.to_vec()),
28            AxisData::F64(values) => Self::F64(values.to_vec()),
29            AxisData::Buffer(buffer) => Self::Buffer(buffer.clone()),
30        }
31    }
32
33    pub async fn export_f64(
34        &self,
35        device: &Arc<wgpu::Device>,
36        queue: &Arc<wgpu::Queue>,
37        len: usize,
38        scalar: ScalarType,
39    ) -> Result<Vec<f64>, String> {
40        match self {
41            Self::F32(values) => Ok(values.iter().map(|value| f64::from(*value)).collect()),
42            Self::F64(values) => Ok(values.clone()),
43            Self::Buffer(buffer) => {
44                crate::gpu::util::readback_scalar_buffer_f64(device, queue, buffer, len, scalar)
45                    .await
46            }
47        }
48    }
49}
50
51pub fn axis_storage_buffer(
52    device: &Arc<wgpu::Device>,
53    label: &'static str,
54    axis: &AxisData<'_>,
55    scalar: ScalarType,
56) -> Result<Arc<wgpu::Buffer>, String> {
57    match axis {
58        AxisData::Buffer(buffer) => Ok(buffer.clone()),
59        AxisData::F32(values) => {
60            if scalar != ScalarType::F32 {
61                return Err(format!("{label}: expected f64 axis data for f64 shader"));
62            }
63            Ok(Arc::new(device.create_buffer_init(
64                &wgpu::util::BufferInitDescriptor {
65                    label: Some(label),
66                    contents: bytemuck::cast_slice(values),
67                    usage: wgpu::BufferUsages::STORAGE
68                        | wgpu::BufferUsages::COPY_DST
69                        | wgpu::BufferUsages::COPY_SRC,
70                },
71            )))
72        }
73        AxisData::F64(values) => {
74            if scalar != ScalarType::F64 {
75                return Err(format!("{label}: expected f32 axis data for f32 shader"));
76            }
77            Ok(Arc::new(device.create_buffer_init(
78                &wgpu::util::BufferInitDescriptor {
79                    label: Some(label),
80                    contents: bytemuck::cast_slice(values),
81                    usage: wgpu::BufferUsages::STORAGE
82                        | wgpu::BufferUsages::COPY_DST
83                        | wgpu::BufferUsages::COPY_SRC,
84                },
85            )))
86        }
87    }
88}