Skip to main content

runmat_plot/gpu/
surface.rs

1use crate::core::renderer::Vertex;
2use crate::core::scene::GpuVertexBuffer;
3use crate::gpu::axis::{axis_storage_buffer, AxisData};
4use crate::gpu::shaders;
5use crate::gpu::{tuning, ScalarType};
6use std::sync::Arc;
7use wgpu::util::DeviceExt;
8
9/// Axis data source used by the GPU surface vertex packer.
10pub type SurfaceAxis<'a> = AxisData<'a>;
11
12/// Inputs required to pack surface vertices directly on the GPU.
13pub struct SurfaceGpuInputs<'a> {
14    pub x_axis: SurfaceAxis<'a>,
15    pub y_axis: SurfaceAxis<'a>,
16    pub z_buffer: Arc<wgpu::Buffer>,
17    pub color_table: &'a [[f32; 4]],
18    pub x_len: u32,
19    pub y_len: u32,
20    pub scalar: ScalarType,
21}
22
/// Parameters describing how the GPU vertices should be generated.
///
/// This is a small plain-data struct, so it derives `Debug` (for logging), `Clone`/`Copy`
/// (so callers can reuse one set of parameters across calls) and `PartialEq` (for tests and
/// change detection).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SurfaceGpuParams {
    /// Lower end of the Z range forwarded to the shader.
    /// NOTE(review): presumably used to normalise Z for color lookup — confirm in the shader.
    pub min_z: f32,
    /// Upper end of the Z range; the packer clamps it relative to `min_z` so the range is
    /// never inverted.
    pub max_z: f32,
    /// Alpha value forwarded to the shader (presumably vertex opacity — confirm).
    pub alpha: f32,
    /// Sets the shader's `flatten` flag when `true`.
    /// NOTE(review): presumably collapses Z onto a plane — confirm in the shader.
    pub flatten_z: bool,
    /// Step between sampled X indices; `0` is treated as `1` by the packer.
    pub x_stride: u32,
    /// Step between sampled Y indices; `0` is treated as `1` by the packer.
    pub y_stride: u32,
    /// Number of output vertices along X; `0` is treated as `1` by the packer.
    pub lod_x_len: u32,
    /// Number of output vertices along Y; `0` is treated as `1`. The packer emits
    /// `lod_x_len * lod_y_len` vertices in total.
    pub lod_y_len: u32,
}
34
/// Uniform block uploaded (via `bytemuck::bytes_of`) to binding 5 of the surface packing
/// compute shader.
///
/// `#[repr(C)]` plus the `Pod`/`Zeroable` derives let the struct be copied to the GPU
/// byte-for-byte. NOTE(review): field order and sizes presumably mirror the uniform struct in
/// the `shaders::surface` WGSL templates — any change here must be made there too.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct SurfaceUniforms {
    /// Copied from `SurfaceGpuParams::min_z`.
    min_z: f32,
    /// `SurfaceGpuParams::max_z`, clamped by `pack_surface_vertices` relative to `min_z`.
    max_z: f32,
    /// Copied from `SurfaceGpuParams::alpha`.
    alpha: f32,
    /// `1` when `flatten_z` is set, otherwise `0`; a `u32` because `bool` is not `Pod`.
    flatten: u32,
    /// Sample count along X from `SurfaceGpuInputs`.
    x_len: u32,
    /// Sample count along Y from `SurfaceGpuInputs`.
    y_len: u32,
    /// Output vertex count along X (at least 1).
    lod_x_len: u32,
    /// Output vertex count along Y (at least 1).
    lod_y_len: u32,
    /// X sampling stride (at least 1).
    x_stride: u32,
    /// Y sampling stride (at least 1).
    y_stride: u32,
    /// Number of entries in the color table storage buffer.
    color_table_len: u32,
    /// Explicit padding: with it the struct is 12 x 4 = 48 bytes. NOTE(review): presumably
    /// kept so the size stays a multiple of 16 for the WGSL uniform layout — confirm.
    _pad: u32,
}
51
52/// Builds a GPU-resident vertex buffer for surface plots directly from provider-owned Z data.
53pub fn pack_surface_vertices(
54    device: &Arc<wgpu::Device>,
55    queue: &Arc<wgpu::Queue>,
56    inputs: &SurfaceGpuInputs<'_>,
57    params: &SurfaceGpuParams,
58) -> Result<GpuVertexBuffer, String> {
59    if inputs.x_len < 2 || inputs.y_len < 2 {
60        return Err("surf: axis vectors must contain at least two elements".to_string());
61    }
62
63    let workgroup_size = tuning::effective_workgroup_size();
64    let shader = compile_shader(device, workgroup_size, inputs.scalar);
65
66    let x_buffer = axis_storage_buffer(device, "surface-x-axis", &inputs.x_axis, inputs.scalar)?;
67    let y_buffer = axis_storage_buffer(device, "surface-y-axis", &inputs.y_axis, inputs.scalar)?;
68
69    let color_buffer = Arc::new(
70        device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
71            label: Some("surface-color-table"),
72            contents: bytemuck::cast_slice(inputs.color_table),
73            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
74        }),
75    );
76
77    let lod_x_len = params.lod_x_len.max(1);
78    let lod_y_len = params.lod_y_len.max(1);
79    let vertex_count = lod_x_len
80        .checked_mul(lod_y_len)
81        .ok_or_else(|| "surf: grid dimensions overflowed vertex count".to_string())?;
82    let output_size = vertex_count as u64 * std::mem::size_of::<Vertex>() as u64;
83    let output_buffer = Arc::new(device.create_buffer(&wgpu::BufferDescriptor {
84        label: Some("surface-gpu-vertices"),
85        size: output_size,
86        usage: wgpu::BufferUsages::STORAGE
87            | wgpu::BufferUsages::VERTEX
88            | wgpu::BufferUsages::COPY_DST,
89        mapped_at_creation: false,
90    }));
91
92    let uniforms = SurfaceUniforms {
93        min_z: params.min_z,
94        max_z: params.max_z.max(params.min_z + 1e-6),
95        alpha: params.alpha,
96        flatten: if params.flatten_z { 1 } else { 0 },
97        x_len: inputs.x_len,
98        y_len: inputs.y_len,
99        lod_x_len,
100        lod_y_len,
101        x_stride: params.x_stride.max(1),
102        y_stride: params.y_stride.max(1),
103        color_table_len: inputs.color_table.len() as u32,
104        _pad: 0,
105    };
106    let uniform_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
107        label: Some("surface-pack-uniforms"),
108        contents: bytemuck::bytes_of(&uniforms),
109        usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
110    });
111
112    let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
113        label: Some("surface-pack-bind-layout"),
114        entries: &[
115            wgpu::BindGroupLayoutEntry {
116                binding: 0,
117                visibility: wgpu::ShaderStages::COMPUTE,
118                ty: wgpu::BindingType::Buffer {
119                    ty: wgpu::BufferBindingType::Storage { read_only: true },
120                    has_dynamic_offset: false,
121                    min_binding_size: None,
122                },
123                count: None,
124            },
125            wgpu::BindGroupLayoutEntry {
126                binding: 1,
127                visibility: wgpu::ShaderStages::COMPUTE,
128                ty: wgpu::BindingType::Buffer {
129                    ty: wgpu::BufferBindingType::Storage { read_only: true },
130                    has_dynamic_offset: false,
131                    min_binding_size: None,
132                },
133                count: None,
134            },
135            wgpu::BindGroupLayoutEntry {
136                binding: 2,
137                visibility: wgpu::ShaderStages::COMPUTE,
138                ty: wgpu::BindingType::Buffer {
139                    ty: wgpu::BufferBindingType::Storage { read_only: true },
140                    has_dynamic_offset: false,
141                    min_binding_size: None,
142                },
143                count: None,
144            },
145            wgpu::BindGroupLayoutEntry {
146                binding: 3,
147                visibility: wgpu::ShaderStages::COMPUTE,
148                ty: wgpu::BindingType::Buffer {
149                    ty: wgpu::BufferBindingType::Storage { read_only: true },
150                    has_dynamic_offset: false,
151                    min_binding_size: None,
152                },
153                count: None,
154            },
155            wgpu::BindGroupLayoutEntry {
156                binding: 4,
157                visibility: wgpu::ShaderStages::COMPUTE,
158                ty: wgpu::BindingType::Buffer {
159                    ty: wgpu::BufferBindingType::Storage { read_only: false },
160                    has_dynamic_offset: false,
161                    min_binding_size: None,
162                },
163                count: None,
164            },
165            wgpu::BindGroupLayoutEntry {
166                binding: 5,
167                visibility: wgpu::ShaderStages::COMPUTE,
168                ty: wgpu::BindingType::Buffer {
169                    ty: wgpu::BufferBindingType::Uniform,
170                    has_dynamic_offset: false,
171                    min_binding_size: None,
172                },
173                count: None,
174            },
175        ],
176    });
177
178    let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
179        label: Some("surface-pack-pipeline-layout"),
180        bind_group_layouts: &[&bind_group_layout],
181        push_constant_ranges: &[],
182    });
183
184    let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
185        label: Some("surface-pack-pipeline"),
186        layout: Some(&pipeline_layout),
187        module: &shader,
188        entry_point: "main",
189    });
190
191    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
192        label: Some("surface-pack-bind-group"),
193        layout: &bind_group_layout,
194        entries: &[
195            wgpu::BindGroupEntry {
196                binding: 0,
197                resource: x_buffer.as_ref().as_entire_binding(),
198            },
199            wgpu::BindGroupEntry {
200                binding: 1,
201                resource: y_buffer.as_ref().as_entire_binding(),
202            },
203            wgpu::BindGroupEntry {
204                binding: 2,
205                resource: inputs.z_buffer.as_ref().as_entire_binding(),
206            },
207            wgpu::BindGroupEntry {
208                binding: 3,
209                resource: color_buffer.as_ref().as_entire_binding(),
210            },
211            wgpu::BindGroupEntry {
212                binding: 4,
213                resource: output_buffer.as_ref().as_entire_binding(),
214            },
215            wgpu::BindGroupEntry {
216                binding: 5,
217                resource: uniform_buffer.as_entire_binding(),
218            },
219        ],
220    });
221
222    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
223        label: Some("surface-pack-encoder"),
224    });
225    {
226        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
227            label: Some("surface-pack-pass"),
228            timestamp_writes: None,
229        });
230        pass.set_pipeline(&pipeline);
231        pass.set_bind_group(0, &bind_group, &[]);
232        let workgroups = vertex_count.div_ceil(workgroup_size);
233        pass.dispatch_workgroups(workgroups, 1, 1);
234    }
235    queue.submit(Some(encoder.finish()));
236
237    Ok(GpuVertexBuffer::new(output_buffer, vertex_count as usize))
238}
239
240fn compile_shader(
241    device: &Arc<wgpu::Device>,
242    workgroup_size: u32,
243    scalar: ScalarType,
244) -> wgpu::ShaderModule {
245    let template = match scalar {
246        ScalarType::F32 => shaders::surface::F32,
247        ScalarType::F64 => shaders::surface::F64,
248    };
249    let source = template.replace("{{WORKGROUP_SIZE}}", &workgroup_size.to_string());
250    device.create_shader_module(wgpu::ShaderModuleDescriptor {
251        label: Some("surface-pack-shader"),
252        source: wgpu::ShaderSource::Wgsl(source.into()),
253    })
254}
255
#[cfg(test)]
mod stress_tests {
    use super::*;
    use pollster::FutureExt;

    /// Returns a device/queue pair for the opt-in GPU stress tests.
    ///
    /// Yields `None`, so the calling test returns early, unless `RUNMAT_PLOT_FORCE_GPU_TESTS`
    /// is set and `RUNMAT_PLOT_SKIP_GPU_TESTS` is not. It also yields `None` when no adapter
    /// or device can be obtained.
    fn maybe_device() -> Option<(Arc<wgpu::Device>, Arc<wgpu::Queue>)> {
        if std::env::var("RUNMAT_PLOT_SKIP_GPU_TESTS").is_ok()
            || std::env::var("RUNMAT_PLOT_FORCE_GPU_TESTS").is_err()
        {
            return None;
        }
        let instance = wgpu::Instance::default();
        let adapter = instance
            .request_adapter(&wgpu::RequestAdapterOptions {
                power_preference: wgpu::PowerPreference::HighPerformance,
                compatible_surface: None,
                force_fallback_adapter: false,
            })
            .block_on()?;
        let limits = adapter.limits();
        let (device, queue) = adapter
            .request_device(
                &wgpu::DeviceDescriptor {
                    label: Some("runmat-plot-surface-test-device"),
                    required_features: wgpu::Features::empty(),
                    required_limits: limits,
                },
                None,
            )
            .block_on()
            .ok()?;
        Some((Arc::new(device), Arc::new(queue)))
    }

    #[test]
    fn gpu_packer_handles_large_surface() {
        let Some((device, queue)) = maybe_device() else {
            return;
        };
        let x_len = 2048u32;
        let y_len = 2048u32;
        let total = (x_len * y_len) as usize;
        let x_axis: Vec<f32> = (0..x_len).map(|i| i as f32 * 0.1).collect();
        let y_axis: Vec<f32> = (0..y_len).map(|i| i as f32 * 0.1).collect();
        // Samples laid out with X varying fastest; values of (sin + cos) / 2 stay in [-1, 1].
        let z_data: Vec<f32> = (0..total)
            .map(|idx| {
                let x = (idx % x_len as usize) as f32 * 0.01;
                let y = (idx / x_len as usize) as f32 * 0.01;
                (x.sin() + y.cos()) * 0.5
            })
            .collect();
        let z_buffer = Arc::new(
            device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
                label: Some("surface-test-z"),
                contents: bytemuck::cast_slice(&z_data),
                usage: wgpu::BufferUsages::STORAGE,
            }),
        );

        let color_table: Vec<[f32; 4]> = (0..256)
            .map(|i| {
                let t = i as f32 / 255.0;
                [t, 1.0 - t, 0.5, 1.0]
            })
            .collect();

        let inputs = SurfaceGpuInputs {
            x_axis: SurfaceAxis::F32(&x_axis),
            y_axis: SurfaceAxis::F32(&y_axis),
            z_buffer,
            color_table: &color_table,
            x_len,
            y_len,
            scalar: ScalarType::F32,
        };
        let stride = 8;
        let lod_x_len = x_len.div_ceil(stride);
        let lod_y_len = y_len.div_ceil(stride);
        let params = SurfaceGpuParams {
            min_z: -1.0,
            max_z: 1.0,
            alpha: 1.0,
            flatten_z: false,
            x_stride: stride,
            y_stride: stride,
            lod_x_len,
            lod_y_len,
        };

        let gpu_vertices =
            pack_surface_vertices(&device, &queue, &inputs, &params).expect("surface pack failed");
        // The packer only submits the dispatch. Block until the GPU has finished it so the
        // test exercises the compute pass end to end instead of returning while it may
        // still be queued.
        let _ = device.poll(wgpu::Maintain::Wait);
        assert!(gpu_vertices.vertex_count > 0);
        assert_eq!(gpu_vertices.vertex_count, (lod_x_len * lod_y_len) as usize);
    }
}