Skip to main content

runmat_plot/gpu/
surface.rs

1use crate::core::renderer::Vertex;
2use crate::core::scene::GpuVertexBuffer;
3use crate::gpu::axis::{axis_storage_buffer, AxisData};
4use crate::gpu::shaders;
5use crate::gpu::{tuning, ScalarType};
6use std::sync::Arc;
7use wgpu::util::DeviceExt;
8
9/// Axis data source used by the GPU surface vertex packer.
10pub type SurfaceAxis<'a> = AxisData<'a>;
11
12/// Inputs required to pack surface vertices directly on the GPU.
13pub struct SurfaceGpuInputs<'a> {
14    pub x_axis: SurfaceAxis<'a>,
15    pub y_axis: SurfaceAxis<'a>,
16    pub z_buffer: Arc<wgpu::Buffer>,
17    pub color_table: &'a [[f32; 4]],
18    pub x_len: u32,
19    pub y_len: u32,
20    pub scalar: ScalarType,
21}
22
/// Parameters describing how the GPU vertices should be generated.
///
/// `pack_surface_vertices` sanitizes these before they reach the shader: strides and LOD
/// lengths of 0 are treated as 1, and a non-finite or empty Z range is widened.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SurfaceGpuParams {
    /// Lower end of the Z range passed to the shader (presumably the colormap's low end).
    pub min_z: f32,
    /// Upper end of the Z range passed to the shader (presumably the colormap's high end).
    pub max_z: f32,
    /// Alpha value forwarded to the shader.
    pub alpha: f32,
    /// Forwarded to the shader as its `flatten` flag (1 when true, 0 when false).
    pub flatten_z: bool,
    /// Decimation stride along X between source samples; 0 is treated as 1.
    pub x_stride: u32,
    /// Decimation stride along Y between source samples; 0 is treated as 1.
    pub y_stride: u32,
    /// Output vertices along X; `lod_x_len * lod_y_len` vertices are generated.
    pub lod_x_len: u32,
    /// Output vertices along Y; `lod_x_len * lod_y_len` vertices are generated.
    pub lod_y_len: u32,
}
34
/// Uniform block uploaded verbatim (via `bytemuck::bytes_of`) to binding 5 of the
/// surface-packing compute shader.
///
/// `#[repr(C)]` fixes the field order and layout. The WGSL uniform struct in
/// `shaders::surface` (not visible here) presumably declares the same twelve 4-byte fields
/// in the same order, so keep the two in sync when editing either.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct SurfaceUniforms {
    // Z range after the packer's finiteness / non-empty sanitization.
    min_z: f32,
    max_z: f32,
    // Forwarded from `SurfaceGpuParams::alpha`.
    alpha: f32,
    // `flatten_z` encoded as 1/0; WGSL `bool` is not host-shareable, so a u32 is used.
    flatten: u32,
    // Full-resolution axis lengths.
    x_len: u32,
    y_len: u32,
    // Output grid dimensions; the packer dispatches for `lod_x_len * lod_y_len` vertices.
    lod_x_len: u32,
    lod_y_len: u32,
    // Decimation strides, clamped to at least 1 by the packer.
    x_stride: u32,
    y_stride: u32,
    // Number of RGBA entries in the color-table storage buffer.
    color_table_len: u32,
    // Padding; presumably keeps the struct at 48 bytes (a multiple of 16) — confirm
    // against the WGSL declaration.
    _pad: u32,
}
51
52/// Builds a GPU-resident vertex buffer for surface plots directly from provider-owned Z data.
53pub fn pack_surface_vertices(
54    device: &Arc<wgpu::Device>,
55    queue: &Arc<wgpu::Queue>,
56    inputs: &SurfaceGpuInputs<'_>,
57    params: &SurfaceGpuParams,
58) -> Result<GpuVertexBuffer, String> {
59    if inputs.x_len < 2 || inputs.y_len < 2 {
60        return Err("surf: axis vectors must contain at least two elements".to_string());
61    }
62
63    let workgroup_size = tuning::effective_workgroup_size();
64    let shader = compile_shader(device, workgroup_size, inputs.scalar);
65
66    let x_buffer = axis_storage_buffer(device, "surface-x-axis", &inputs.x_axis, inputs.scalar)?;
67    let y_buffer = axis_storage_buffer(device, "surface-y-axis", &inputs.y_axis, inputs.scalar)?;
68
69    let color_buffer = Arc::new(
70        device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
71            label: Some("surface-color-table"),
72            contents: bytemuck::cast_slice(inputs.color_table),
73            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
74        }),
75    );
76
77    let lod_x_len = params.lod_x_len.max(1);
78    let lod_y_len = params.lod_y_len.max(1);
79    let vertex_count = lod_x_len
80        .checked_mul(lod_y_len)
81        .ok_or_else(|| "surf: grid dimensions overflowed vertex count".to_string())?;
82    let output_size = vertex_count as u64 * std::mem::size_of::<Vertex>() as u64;
83    let output_buffer = Arc::new(device.create_buffer(&wgpu::BufferDescriptor {
84        label: Some("surface-gpu-vertices"),
85        size: output_size,
86        usage: wgpu::BufferUsages::STORAGE
87            | wgpu::BufferUsages::VERTEX
88            | wgpu::BufferUsages::COPY_DST,
89        mapped_at_creation: false,
90    }));
91
92    let min_z = if params.min_z.is_finite() {
93        params.min_z
94    } else {
95        tracing::warn!(
96            target: "runmat_plot::surface_gpu",
97            min_z = params.min_z,
98            "non-finite min_z received; sanitizing to 0.0"
99        );
100        0.0
101    };
102    let mut max_z = if params.max_z.is_finite() {
103        params.max_z
104    } else {
105        tracing::warn!(
106            target: "runmat_plot::surface_gpu",
107            max_z = params.max_z,
108            "non-finite max_z received; sanitizing to min_z + 1.0"
109        );
110        min_z + 1.0
111    };
112    if max_z <= min_z {
113        tracing::warn!(
114            target: "runmat_plot::surface_gpu",
115            min_z,
116            max_z,
117            "invalid z range received; forcing epsilon span"
118        );
119        max_z = min_z + 1e-6;
120    }
121
122    let uniforms = SurfaceUniforms {
123        min_z,
124        max_z,
125        alpha: params.alpha,
126        flatten: if params.flatten_z { 1 } else { 0 },
127        x_len: inputs.x_len,
128        y_len: inputs.y_len,
129        lod_x_len,
130        lod_y_len,
131        x_stride: params.x_stride.max(1),
132        y_stride: params.y_stride.max(1),
133        color_table_len: inputs.color_table.len() as u32,
134        _pad: 0,
135    };
136    let uniform_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
137        label: Some("surface-pack-uniforms"),
138        contents: bytemuck::bytes_of(&uniforms),
139        usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
140    });
141
142    let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
143        label: Some("surface-pack-bind-layout"),
144        entries: &[
145            wgpu::BindGroupLayoutEntry {
146                binding: 0,
147                visibility: wgpu::ShaderStages::COMPUTE,
148                ty: wgpu::BindingType::Buffer {
149                    ty: wgpu::BufferBindingType::Storage { read_only: true },
150                    has_dynamic_offset: false,
151                    min_binding_size: None,
152                },
153                count: None,
154            },
155            wgpu::BindGroupLayoutEntry {
156                binding: 1,
157                visibility: wgpu::ShaderStages::COMPUTE,
158                ty: wgpu::BindingType::Buffer {
159                    ty: wgpu::BufferBindingType::Storage { read_only: true },
160                    has_dynamic_offset: false,
161                    min_binding_size: None,
162                },
163                count: None,
164            },
165            wgpu::BindGroupLayoutEntry {
166                binding: 2,
167                visibility: wgpu::ShaderStages::COMPUTE,
168                ty: wgpu::BindingType::Buffer {
169                    ty: wgpu::BufferBindingType::Storage { read_only: true },
170                    has_dynamic_offset: false,
171                    min_binding_size: None,
172                },
173                count: None,
174            },
175            wgpu::BindGroupLayoutEntry {
176                binding: 3,
177                visibility: wgpu::ShaderStages::COMPUTE,
178                ty: wgpu::BindingType::Buffer {
179                    ty: wgpu::BufferBindingType::Storage { read_only: true },
180                    has_dynamic_offset: false,
181                    min_binding_size: None,
182                },
183                count: None,
184            },
185            wgpu::BindGroupLayoutEntry {
186                binding: 4,
187                visibility: wgpu::ShaderStages::COMPUTE,
188                ty: wgpu::BindingType::Buffer {
189                    ty: wgpu::BufferBindingType::Storage { read_only: false },
190                    has_dynamic_offset: false,
191                    min_binding_size: None,
192                },
193                count: None,
194            },
195            wgpu::BindGroupLayoutEntry {
196                binding: 5,
197                visibility: wgpu::ShaderStages::COMPUTE,
198                ty: wgpu::BindingType::Buffer {
199                    ty: wgpu::BufferBindingType::Uniform,
200                    has_dynamic_offset: false,
201                    min_binding_size: None,
202                },
203                count: None,
204            },
205        ],
206    });
207
208    let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
209        label: Some("surface-pack-pipeline-layout"),
210        bind_group_layouts: &[&bind_group_layout],
211        push_constant_ranges: &[],
212    });
213
214    let pipeline =
215        device.create_compute_pipeline(&crate::wgpu_compat::wgpu_compute_pipeline_descriptor! {
216            label: Some("surface-pack-pipeline"),
217            layout: Some(&pipeline_layout),
218            module: &shader,
219            entry_point: "main",
220        });
221
222    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
223        label: Some("surface-pack-bind-group"),
224        layout: &bind_group_layout,
225        entries: &[
226            wgpu::BindGroupEntry {
227                binding: 0,
228                resource: x_buffer.as_ref().as_entire_binding(),
229            },
230            wgpu::BindGroupEntry {
231                binding: 1,
232                resource: y_buffer.as_ref().as_entire_binding(),
233            },
234            wgpu::BindGroupEntry {
235                binding: 2,
236                resource: inputs.z_buffer.as_ref().as_entire_binding(),
237            },
238            wgpu::BindGroupEntry {
239                binding: 3,
240                resource: color_buffer.as_ref().as_entire_binding(),
241            },
242            wgpu::BindGroupEntry {
243                binding: 4,
244                resource: output_buffer.as_ref().as_entire_binding(),
245            },
246            wgpu::BindGroupEntry {
247                binding: 5,
248                resource: uniform_buffer.as_entire_binding(),
249            },
250        ],
251    });
252
253    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
254        label: Some("surface-pack-encoder"),
255    });
256    {
257        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
258            label: Some("surface-pack-pass"),
259            timestamp_writes: None,
260        });
261        pass.set_pipeline(&pipeline);
262        pass.set_bind_group(0, &bind_group, &[]);
263        let workgroups = vertex_count.div_ceil(workgroup_size);
264        pass.dispatch_workgroups(workgroups, 1, 1);
265    }
266    queue.submit(Some(encoder.finish()));
267
268    Ok(GpuVertexBuffer::new(output_buffer, vertex_count as usize))
269}
270
271fn compile_shader(
272    device: &Arc<wgpu::Device>,
273    workgroup_size: u32,
274    scalar: ScalarType,
275) -> wgpu::ShaderModule {
276    let template = match scalar {
277        ScalarType::F32 => shaders::surface::F32,
278        ScalarType::F64 => shaders::surface::F64,
279    };
280    let source = template.replace("{{WORKGROUP_SIZE}}", &workgroup_size.to_string());
281    device.create_shader_module(wgpu::ShaderModuleDescriptor {
282        label: Some("surface-pack-shader"),
283        source: wgpu::ShaderSource::Wgsl(source.into()),
284    })
285}
286
/// GPU-backed tests for the surface packer.
///
/// Each test silently returns unless a real device is available and explicitly requested
/// (see `maybe_device`).
#[cfg(test)]
mod stress_tests {
    use super::*;
    use pollster::FutureExt;

    /// Returns a device/queue pair only when GPU tests are opted into.
    ///
    /// Yields `None` if `RUNMAT_PLOT_SKIP_GPU_TESTS` is set, if `RUNMAT_PLOT_FORCE_GPU_TESTS`
    /// is unset, or if no adapter/device can be obtained.
    fn maybe_device() -> Option<(Arc<wgpu::Device>, Arc<wgpu::Queue>)> {
        if std::env::var("RUNMAT_PLOT_SKIP_GPU_TESTS").is_ok()
            || std::env::var("RUNMAT_PLOT_FORCE_GPU_TESTS").is_err()
        {
            return None;
        }
        let instance = wgpu::Instance::default();
        let adapter = instance
            .request_adapter(&wgpu::RequestAdapterOptions {
                power_preference: wgpu::PowerPreference::HighPerformance,
                compatible_surface: None,
                force_fallback_adapter: false,
            })
            .block_on()?;
        let limits = adapter.limits();
        let (device, queue) = adapter
            .request_device(
                &crate::wgpu_compat::device_descriptor(
                    Some("runmat-plot-surface-test-device"),
                    wgpu::Features::empty(),
                    limits,
                ),
                None,
            )
            .block_on()
            .ok()?;
        Some((Arc::new(device), Arc::new(queue)))
    }

    /// Linear 256-entry RGBA ramp used as the colormap in every test.
    fn color_ramp() -> Vec<[f32; 4]> {
        (0..256)
            .map(|i| {
                let t = i as f32 / 255.0;
                [t, 1.0 - t, 0.5, 1.0]
            })
            .collect()
    }

    /// Uploads `data` as a storage buffer suitable for `SurfaceGpuInputs::z_buffer`.
    fn z_storage_buffer(device: &wgpu::Device, data: &[f32]) -> Arc<wgpu::Buffer> {
        Arc::new(
            device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
                label: Some("surface-test-z"),
                contents: bytemuck::cast_slice(data),
                usage: wgpu::BufferUsages::STORAGE,
            }),
        )
    }

    /// Stride-1 parameters with a fixed [-1, 1] Z range and the given output grid size.
    fn unit_stride_params(lod_x_len: u32, lod_y_len: u32) -> SurfaceGpuParams {
        SurfaceGpuParams {
            min_z: -1.0,
            max_z: 1.0,
            alpha: 1.0,
            flatten_z: false,
            x_stride: 1,
            y_stride: 1,
            lod_x_len,
            lod_y_len,
        }
    }

    #[test]
    fn gpu_packer_handles_large_surface() {
        let Some((device, queue)) = maybe_device() else {
            return;
        };
        let x_len = 2048u32;
        let y_len = 2048u32;
        let x_axis: Vec<f32> = (0..x_len).map(|i| i as f32 * 0.1).collect();
        let y_axis: Vec<f32> = (0..y_len).map(|i| i as f32 * 0.1).collect();
        // The test grid is generated with x varying fastest: idx = row * x_len + column.
        let z_data: Vec<f32> = (0..(x_len * y_len) as usize)
            .map(|idx| {
                let x = (idx % x_len as usize) as f32 * 0.01;
                let y = (idx / x_len as usize) as f32 * 0.01;
                (x.sin() + y.cos()) * 0.5
            })
            .collect();
        let color_table = color_ramp();

        let inputs = SurfaceGpuInputs {
            x_axis: SurfaceAxis::F32(&x_axis),
            y_axis: SurfaceAxis::F32(&y_axis),
            z_buffer: z_storage_buffer(&device, &z_data),
            color_table: &color_table,
            x_len,
            y_len,
            scalar: ScalarType::F32,
        };
        let stride = 8;
        let lod_x_len = x_len.div_ceil(stride);
        let lod_y_len = y_len.div_ceil(stride);
        let params = SurfaceGpuParams {
            min_z: -1.0,
            max_z: 1.0,
            alpha: 1.0,
            flatten_z: false,
            x_stride: stride,
            y_stride: stride,
            lod_x_len,
            lod_y_len,
        };

        let gpu_vertices =
            pack_surface_vertices(&device, &queue, &inputs, &params).expect("surface pack failed");
        assert!(gpu_vertices.vertex_count > 0);
        assert_eq!(gpu_vertices.vertex_count, (lod_x_len * lod_y_len) as usize);
    }

    #[test]
    fn gpu_packer_rejects_short_axes() {
        let Some((device, queue)) = maybe_device() else {
            return;
        };
        let axis = [0.0f32];
        let color_table = color_ramp();
        let inputs = SurfaceGpuInputs {
            x_axis: SurfaceAxis::F32(&axis[..]),
            y_axis: SurfaceAxis::F32(&axis[..]),
            z_buffer: z_storage_buffer(&device, &[0.0]),
            color_table: &color_table,
            x_len: 1,
            y_len: 1,
            scalar: ScalarType::F32,
        };
        let err = pack_surface_vertices(&device, &queue, &inputs, &unit_stride_params(1, 1))
            .err()
            .expect("single-sample axes must be rejected");
        assert!(err.contains("at least two elements"), "unexpected error: {err}");
    }

    #[test]
    fn gpu_packer_rejects_vertex_count_overflow() {
        let Some((device, queue)) = maybe_device() else {
            return;
        };
        let axis = [0.0f32, 1.0];
        let color_table = color_ramp();
        let inputs = SurfaceGpuInputs {
            x_axis: SurfaceAxis::F32(&axis[..]),
            y_axis: SurfaceAxis::F32(&axis[..]),
            z_buffer: z_storage_buffer(&device, &[0.0; 4]),
            color_table: &color_table,
            x_len: 2,
            y_len: 2,
            scalar: ScalarType::F32,
        };
        let err =
            pack_surface_vertices(&device, &queue, &inputs, &unit_stride_params(u32::MAX, 2))
                .err()
                .expect("an LOD grid overflowing u32 must be rejected");
        assert!(err.contains("overflowed vertex count"), "unexpected error: {err}");
    }
}