Skip to main content

runmat_plot/gpu/
surface.rs

1use crate::core::renderer::Vertex;
2use crate::core::scene::GpuVertexBuffer;
3use crate::gpu::axis::{axis_storage_buffer, AxisData};
4use crate::gpu::shaders;
5use crate::gpu::{tuning, ScalarType};
6use std::sync::Arc;
7use wgpu::util::DeviceExt;
8
9/// Axis data source used by the GPU surface vertex packer.
10pub type SurfaceAxis<'a> = AxisData<'a>;
11
12/// Inputs required to pack surface vertices directly on the GPU.
13pub struct SurfaceGpuInputs<'a> {
14    pub x_axis: SurfaceAxis<'a>,
15    pub y_axis: SurfaceAxis<'a>,
16    pub z_buffer: Arc<wgpu::Buffer>,
17    pub color_table: &'a [[f32; 4]],
18    pub x_len: u32,
19    pub y_len: u32,
20    pub scalar: ScalarType,
21}
22
/// Parameters describing how the GPU vertices should be generated.
///
/// `pack_surface_vertices` copies these into the packing shader's uniform block after the
/// light sanitization noted on each field.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SurfaceGpuParams {
    /// Lower bound of the Z range; a non-finite value is replaced before upload.
    pub min_z: f32,
    /// Upper bound of the Z range; a non-finite value, or one not above `min_z`, is replaced
    /// before upload so the uploaded range is never empty.
    pub max_z: f32,
    /// Alpha value forwarded unchanged to the shader.
    pub alpha: f32,
    /// Uploaded as the shader's `flatten` flag (`1` when true, `0` otherwise).
    pub flatten_z: bool,
    /// Stride along X forwarded to the shader (presumably the source-grid step between packed
    /// vertices); `0` is treated as `1`.
    pub x_stride: u32,
    /// Stride along Y forwarded to the shader (presumably the source-grid step between packed
    /// vertices); `0` is treated as `1`.
    pub y_stride: u32,
    /// Number of packed vertices along X; `0` is treated as `1`. The output buffer holds
    /// `lod_x_len * lod_y_len` vertices.
    pub lod_x_len: u32,
    /// Number of packed vertices along Y; `0` is treated as `1`.
    pub lod_y_len: u32,
}
34
35#[repr(C)]
36#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
37struct SurfaceUniforms {
38    min_z: f32,
39    max_z: f32,
40    alpha: f32,
41    flatten: u32,
42    x_len: u32,
43    y_len: u32,
44    lod_x_len: u32,
45    lod_y_len: u32,
46    x_stride: u32,
47    y_stride: u32,
48    color_table_len: u32,
49    _pad: u32,
50}
51
52/// Builds a GPU-resident vertex buffer for surface plots directly from provider-owned Z data.
53pub fn pack_surface_vertices(
54    device: &Arc<wgpu::Device>,
55    queue: &Arc<wgpu::Queue>,
56    inputs: &SurfaceGpuInputs<'_>,
57    params: &SurfaceGpuParams,
58) -> Result<GpuVertexBuffer, String> {
59    if inputs.x_len < 2 || inputs.y_len < 2 {
60        return Err("surf: axis vectors must contain at least two elements".to_string());
61    }
62
63    let workgroup_size = tuning::effective_workgroup_size();
64    let shader = compile_shader(device, workgroup_size, inputs.scalar);
65
66    let x_buffer = axis_storage_buffer(device, "surface-x-axis", &inputs.x_axis, inputs.scalar)?;
67    let y_buffer = axis_storage_buffer(device, "surface-y-axis", &inputs.y_axis, inputs.scalar)?;
68
69    let color_buffer = Arc::new(
70        device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
71            label: Some("surface-color-table"),
72            contents: bytemuck::cast_slice(inputs.color_table),
73            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
74        }),
75    );
76
77    let lod_x_len = params.lod_x_len.max(1);
78    let lod_y_len = params.lod_y_len.max(1);
79    let vertex_count = lod_x_len
80        .checked_mul(lod_y_len)
81        .ok_or_else(|| "surf: grid dimensions overflowed vertex count".to_string())?;
82    let output_size = vertex_count as u64 * std::mem::size_of::<Vertex>() as u64;
83    let output_buffer = Arc::new(device.create_buffer(&wgpu::BufferDescriptor {
84        label: Some("surface-gpu-vertices"),
85        size: output_size,
86        usage: wgpu::BufferUsages::STORAGE
87            | wgpu::BufferUsages::VERTEX
88            | wgpu::BufferUsages::COPY_DST,
89        mapped_at_creation: false,
90    }));
91
92    let min_z = if params.min_z.is_finite() {
93        params.min_z
94    } else {
95        tracing::warn!(
96            target: "runmat_plot::surface_gpu",
97            min_z = params.min_z,
98            "non-finite min_z received; sanitizing to 0.0"
99        );
100        0.0
101    };
102    let mut max_z = if params.max_z.is_finite() {
103        params.max_z
104    } else {
105        tracing::warn!(
106            target: "runmat_plot::surface_gpu",
107            max_z = params.max_z,
108            "non-finite max_z received; sanitizing to min_z + 1.0"
109        );
110        min_z + 1.0
111    };
112    if max_z <= min_z {
113        tracing::warn!(
114            target: "runmat_plot::surface_gpu",
115            min_z,
116            max_z,
117            "invalid z range received; forcing epsilon span"
118        );
119        max_z = min_z + 1e-6;
120    }
121
122    let uniforms = SurfaceUniforms {
123        min_z,
124        max_z,
125        alpha: params.alpha,
126        flatten: if params.flatten_z { 1 } else { 0 },
127        x_len: inputs.x_len,
128        y_len: inputs.y_len,
129        lod_x_len,
130        lod_y_len,
131        x_stride: params.x_stride.max(1),
132        y_stride: params.y_stride.max(1),
133        color_table_len: inputs.color_table.len() as u32,
134        _pad: 0,
135    };
136    let uniform_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
137        label: Some("surface-pack-uniforms"),
138        contents: bytemuck::bytes_of(&uniforms),
139        usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
140    });
141
142    let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
143        label: Some("surface-pack-bind-layout"),
144        entries: &[
145            wgpu::BindGroupLayoutEntry {
146                binding: 0,
147                visibility: wgpu::ShaderStages::COMPUTE,
148                ty: wgpu::BindingType::Buffer {
149                    ty: wgpu::BufferBindingType::Storage { read_only: true },
150                    has_dynamic_offset: false,
151                    min_binding_size: None,
152                },
153                count: None,
154            },
155            wgpu::BindGroupLayoutEntry {
156                binding: 1,
157                visibility: wgpu::ShaderStages::COMPUTE,
158                ty: wgpu::BindingType::Buffer {
159                    ty: wgpu::BufferBindingType::Storage { read_only: true },
160                    has_dynamic_offset: false,
161                    min_binding_size: None,
162                },
163                count: None,
164            },
165            wgpu::BindGroupLayoutEntry {
166                binding: 2,
167                visibility: wgpu::ShaderStages::COMPUTE,
168                ty: wgpu::BindingType::Buffer {
169                    ty: wgpu::BufferBindingType::Storage { read_only: true },
170                    has_dynamic_offset: false,
171                    min_binding_size: None,
172                },
173                count: None,
174            },
175            wgpu::BindGroupLayoutEntry {
176                binding: 3,
177                visibility: wgpu::ShaderStages::COMPUTE,
178                ty: wgpu::BindingType::Buffer {
179                    ty: wgpu::BufferBindingType::Storage { read_only: true },
180                    has_dynamic_offset: false,
181                    min_binding_size: None,
182                },
183                count: None,
184            },
185            wgpu::BindGroupLayoutEntry {
186                binding: 4,
187                visibility: wgpu::ShaderStages::COMPUTE,
188                ty: wgpu::BindingType::Buffer {
189                    ty: wgpu::BufferBindingType::Storage { read_only: false },
190                    has_dynamic_offset: false,
191                    min_binding_size: None,
192                },
193                count: None,
194            },
195            wgpu::BindGroupLayoutEntry {
196                binding: 5,
197                visibility: wgpu::ShaderStages::COMPUTE,
198                ty: wgpu::BindingType::Buffer {
199                    ty: wgpu::BufferBindingType::Uniform,
200                    has_dynamic_offset: false,
201                    min_binding_size: None,
202                },
203                count: None,
204            },
205        ],
206    });
207
208    let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
209        label: Some("surface-pack-pipeline-layout"),
210        bind_group_layouts: &[&bind_group_layout],
211        push_constant_ranges: &[],
212    });
213
214    let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
215        label: Some("surface-pack-pipeline"),
216        layout: Some(&pipeline_layout),
217        module: &shader,
218        entry_point: "main",
219    });
220
221    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
222        label: Some("surface-pack-bind-group"),
223        layout: &bind_group_layout,
224        entries: &[
225            wgpu::BindGroupEntry {
226                binding: 0,
227                resource: x_buffer.as_ref().as_entire_binding(),
228            },
229            wgpu::BindGroupEntry {
230                binding: 1,
231                resource: y_buffer.as_ref().as_entire_binding(),
232            },
233            wgpu::BindGroupEntry {
234                binding: 2,
235                resource: inputs.z_buffer.as_ref().as_entire_binding(),
236            },
237            wgpu::BindGroupEntry {
238                binding: 3,
239                resource: color_buffer.as_ref().as_entire_binding(),
240            },
241            wgpu::BindGroupEntry {
242                binding: 4,
243                resource: output_buffer.as_ref().as_entire_binding(),
244            },
245            wgpu::BindGroupEntry {
246                binding: 5,
247                resource: uniform_buffer.as_entire_binding(),
248            },
249        ],
250    });
251
252    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
253        label: Some("surface-pack-encoder"),
254    });
255    {
256        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
257            label: Some("surface-pack-pass"),
258            timestamp_writes: None,
259        });
260        pass.set_pipeline(&pipeline);
261        pass.set_bind_group(0, &bind_group, &[]);
262        let workgroups = vertex_count.div_ceil(workgroup_size);
263        pass.dispatch_workgroups(workgroups, 1, 1);
264    }
265    queue.submit(Some(encoder.finish()));
266
267    Ok(GpuVertexBuffer::new(output_buffer, vertex_count as usize))
268}
269
270fn compile_shader(
271    device: &Arc<wgpu::Device>,
272    workgroup_size: u32,
273    scalar: ScalarType,
274) -> wgpu::ShaderModule {
275    let template = match scalar {
276        ScalarType::F32 => shaders::surface::F32,
277        ScalarType::F64 => shaders::surface::F64,
278    };
279    let source = template.replace("{{WORKGROUP_SIZE}}", &workgroup_size.to_string());
280    device.create_shader_module(wgpu::ShaderModuleDescriptor {
281        label: Some("surface-pack-shader"),
282        source: wgpu::ShaderSource::Wgsl(source.into()),
283    })
284}
285
#[cfg(test)]
mod stress_tests {
    use super::*;
    use pollster::FutureExt;

    /// Returns a device/queue pair, or `None` (the test then silently passes) unless
    /// `RUNMAT_PLOT_FORCE_GPU_TESTS` is set, `RUNMAT_PLOT_SKIP_GPU_TESTS` is unset, and an
    /// adapter and device can both be acquired.
    fn maybe_device() -> Option<(Arc<wgpu::Device>, Arc<wgpu::Queue>)> {
        if std::env::var("RUNMAT_PLOT_SKIP_GPU_TESTS").is_ok()
            || std::env::var("RUNMAT_PLOT_FORCE_GPU_TESTS").is_err()
        {
            return None;
        }
        let instance = wgpu::Instance::default();
        let adapter = instance
            .request_adapter(&wgpu::RequestAdapterOptions {
                power_preference: wgpu::PowerPreference::HighPerformance,
                compatible_surface: None,
                force_fallback_adapter: false,
            })
            .block_on()?;
        let limits = adapter.limits();
        let (device, queue) = adapter
            .request_device(
                &wgpu::DeviceDescriptor {
                    label: Some("runmat-plot-surface-test-device"),
                    required_features: wgpu::Features::empty(),
                    required_limits: limits,
                },
                None,
            )
            .block_on()
            .ok()?;
        Some((Arc::new(device), Arc::new(queue)))
    }

    /// Packs a 2048x2048 surface at stride 8, waits for the dispatch to run, and checks the
    /// reported vertex count.
    #[test]
    fn gpu_packer_handles_large_surface() {
        let Some((device, queue)) = maybe_device() else {
            return;
        };
        let x_len = 2048u32;
        let y_len = 2048u32;
        let row_len = x_len as usize;
        let total = (x_len * y_len) as usize;
        let x_axis: Vec<f32> = (0..x_len).map(|i| i as f32 * 0.1).collect();
        let y_axis: Vec<f32> = (0..y_len).map(|i| i as f32 * 0.1).collect();
        let z_data: Vec<f32> = (0..total)
            .map(|idx| {
                let x = (idx % row_len) as f32 * 0.01;
                let y = (idx / row_len) as f32 * 0.01;
                (x.sin() + y.cos()) * 0.5
            })
            .collect();
        let z_buffer = Arc::new(
            device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
                label: Some("surface-test-z"),
                contents: bytemuck::cast_slice(&z_data),
                usage: wgpu::BufferUsages::STORAGE,
            }),
        );

        let color_table: Vec<[f32; 4]> = (0..256)
            .map(|i| {
                let t = i as f32 / 255.0;
                [t, 1.0 - t, 0.5, 1.0]
            })
            .collect();

        let inputs = SurfaceGpuInputs {
            x_axis: SurfaceAxis::F32(&x_axis),
            y_axis: SurfaceAxis::F32(&y_axis),
            z_buffer,
            color_table: &color_table,
            x_len,
            y_len,
            scalar: ScalarType::F32,
        };
        let stride = 8;
        let lod_x_len = x_len.div_ceil(stride);
        let lod_y_len = y_len.div_ceil(stride);
        let params = SurfaceGpuParams {
            min_z: -1.0,
            max_z: 1.0,
            alpha: 1.0,
            flatten_z: false,
            x_stride: stride,
            y_stride: stride,
            lod_x_len,
            lod_y_len,
        };

        let gpu_vertices =
            pack_surface_vertices(&device, &queue, &inputs, &params).expect("surface pack failed");
        // `pack_surface_vertices` only submits; block until the dispatch has executed so the
        // test covers running the shader, not just encoding and submitting it.
        let _ = device.poll(wgpu::Maintain::Wait);
        assert_eq!(gpu_vertices.vertex_count, (lod_x_len * lod_y_len) as usize);
    }
}