Skip to main content

runmat_plot/gpu/
axis.rs

1use crate::gpu::ScalarType;
2use std::sync::Arc;
3use wgpu::util::DeviceExt;
4
5/// Axis data source used by GPU plot packers.
6///
7/// For GPU kernels that support both f32 and f64 data paths, we treat axis precision as matching
8/// the input scalar type (f32 axes for f32 shaders, f64 axes for f64 shaders). Shaders can cast
9/// to f32 for final vertex positions as needed.
10#[derive(Clone)]
11pub enum AxisData<'a> {
12    F32(&'a [f32]),
13    F64(&'a [f64]),
14    Buffer(Arc<wgpu::Buffer>),
15}
16
17pub fn axis_storage_buffer(
18    device: &Arc<wgpu::Device>,
19    label: &'static str,
20    axis: &AxisData<'_>,
21    scalar: ScalarType,
22) -> Result<Arc<wgpu::Buffer>, String> {
23    match axis {
24        AxisData::Buffer(buffer) => Ok(buffer.clone()),
25        AxisData::F32(values) => {
26            if scalar != ScalarType::F32 {
27                return Err(format!("{label}: expected f64 axis data for f64 shader"));
28            }
29            Ok(Arc::new(device.create_buffer_init(
30                &wgpu::util::BufferInitDescriptor {
31                    label: Some(label),
32                    contents: bytemuck::cast_slice(values),
33                    usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
34                },
35            )))
36        }
37        AxisData::F64(values) => {
38            if scalar != ScalarType::F64 {
39                return Err(format!("{label}: expected f32 axis data for f32 shader"));
40            }
41            Ok(Arc::new(device.create_buffer_init(
42                &wgpu::util::BufferInitDescriptor {
43                    label: Some(label),
44                    contents: bytemuck::cast_slice(values),
45                    usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
46                },
47            )))
48        }
49    }
50}