1use crate::gpu::ScalarType;
2use std::sync::Arc;
3use wgpu::util::DeviceExt;
4
5#[derive(Clone)]
11pub enum AxisData<'a> {
12 F32(&'a [f32]),
13 F64(&'a [f64]),
14 Buffer(Arc<wgpu::Buffer>),
15}
16
17#[derive(Clone, Debug)]
18pub enum OwnedAxisData {
19 F32(Vec<f32>),
20 F64(Vec<f64>),
21 Buffer(Arc<wgpu::Buffer>),
22}
23
24impl OwnedAxisData {
25 pub fn from_axis(axis: &AxisData<'_>) -> Self {
26 match axis {
27 AxisData::F32(values) => Self::F32(values.to_vec()),
28 AxisData::F64(values) => Self::F64(values.to_vec()),
29 AxisData::Buffer(buffer) => Self::Buffer(buffer.clone()),
30 }
31 }
32
33 pub async fn export_f64(
34 &self,
35 device: &Arc<wgpu::Device>,
36 queue: &Arc<wgpu::Queue>,
37 len: usize,
38 scalar: ScalarType,
39 ) -> Result<Vec<f64>, String> {
40 match self {
41 Self::F32(values) => Ok(values.iter().map(|value| f64::from(*value)).collect()),
42 Self::F64(values) => Ok(values.clone()),
43 Self::Buffer(buffer) => {
44 crate::gpu::util::readback_scalar_buffer_f64(device, queue, buffer, len, scalar)
45 .await
46 }
47 }
48 }
49}
50
51pub fn axis_storage_buffer(
52 device: &Arc<wgpu::Device>,
53 label: &'static str,
54 axis: &AxisData<'_>,
55 scalar: ScalarType,
56) -> Result<Arc<wgpu::Buffer>, String> {
57 match axis {
58 AxisData::Buffer(buffer) => Ok(buffer.clone()),
59 AxisData::F32(values) => {
60 if scalar != ScalarType::F32 {
61 return Err(format!("{label}: expected f64 axis data for f64 shader"));
62 }
63 Ok(Arc::new(device.create_buffer_init(
64 &wgpu::util::BufferInitDescriptor {
65 label: Some(label),
66 contents: bytemuck::cast_slice(values),
67 usage: wgpu::BufferUsages::STORAGE
68 | wgpu::BufferUsages::COPY_DST
69 | wgpu::BufferUsages::COPY_SRC,
70 },
71 )))
72 }
73 AxisData::F64(values) => {
74 if scalar != ScalarType::F64 {
75 return Err(format!("{label}: expected f32 axis data for f32 shader"));
76 }
77 Ok(Arc::new(device.create_buffer_init(
78 &wgpu::util::BufferInitDescriptor {
79 label: Some(label),
80 contents: bytemuck::cast_slice(values),
81 usage: wgpu::BufferUsages::STORAGE
82 | wgpu::BufferUsages::COPY_DST
83 | wgpu::BufferUsages::COPY_SRC,
84 },
85 )))
86 }
87 }
88}