Skip to main content

runmat_plot/gpu/
stem.rs

1use crate::core::renderer::Vertex;
2use crate::core::scene::GpuVertexBuffer;
3use crate::gpu::shaders;
4use crate::gpu::{tuning, ScalarType};
5use crate::plots::line::LineStyle;
6use glam::Vec4;
7use std::sync::Arc;
8use wgpu::util::DeviceExt;
9
10#[derive(Clone, Debug)]
11pub struct StemGpuInputs {
12    pub x_buffer: Arc<wgpu::Buffer>,
13    pub y_buffer: Arc<wgpu::Buffer>,
14    pub len: u32,
15    pub scalar: ScalarType,
16}
17
18pub struct StemGpuParams {
19    pub color: Vec4,
20    pub baseline_color: Vec4,
21    pub baseline: f32,
22    pub baseline_visible: bool,
23    pub min_x: f32,
24    pub max_x: f32,
25    pub line_style: LineStyle,
26}
27
28#[repr(C)]
29#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
30struct StemUniforms {
31    color: [f32; 4],
32    baseline_color: [f32; 4],
33    baseline: f32,
34    min_x: f32,
35    max_x: f32,
36    point_count: u32,
37    line_style: u32,
38    baseline_visible: u32,
39}
40
41pub fn pack_vertices_from_xy(
42    device: &Arc<wgpu::Device>,
43    queue: &Arc<wgpu::Queue>,
44    inputs: &StemGpuInputs,
45    params: &StemGpuParams,
46) -> Result<GpuVertexBuffer, String> {
47    let workgroup_size = tuning::effective_workgroup_size();
48    let shader = compile_shader(device, workgroup_size, inputs.scalar);
49    let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
50        label: Some("stem-pack-bind-layout"),
51        entries: &[
52            storage_entry(0, true),
53            storage_entry(1, true),
54            storage_entry(2, false),
55            uniform_entry(3),
56        ],
57    });
58    let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
59        label: Some("stem-pack-pipeline-layout"),
60        bind_group_layouts: &[&bind_group_layout],
61        push_constant_ranges: &[],
62    });
63    let pipeline =
64        device.create_compute_pipeline(&crate::wgpu_compat::wgpu_compute_pipeline_descriptor! {
65            label: Some("stem-pack-pipeline"),
66            layout: Some(&pipeline_layout),
67            module: &shader,
68            entry_point: "main",
69        });
70    let baseline_count = if params.baseline_visible { 2 } else { 0 };
71    let vertex_count = baseline_count as u64 + inputs.len as u64 * 2;
72    let output_buffer = Arc::new(device.create_buffer(&wgpu::BufferDescriptor {
73        label: Some("stem-gpu-vertices"),
74        size: vertex_count * std::mem::size_of::<Vertex>() as u64,
75        usage: wgpu::BufferUsages::STORAGE
76            | wgpu::BufferUsages::VERTEX
77            | wgpu::BufferUsages::COPY_DST
78            | wgpu::BufferUsages::COPY_SRC,
79        mapped_at_creation: false,
80    }));
81    let uniforms = StemUniforms {
82        color: params.color.to_array(),
83        baseline_color: params.baseline_color.to_array(),
84        baseline: params.baseline,
85        min_x: params.min_x,
86        max_x: params.max_x,
87        point_count: inputs.len,
88        line_style: line_style_code(params.line_style),
89        baseline_visible: if params.baseline_visible { 1 } else { 0 },
90    };
91    let uniform_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
92        label: Some("stem-pack-uniforms"),
93        contents: bytemuck::bytes_of(&uniforms),
94        usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
95    });
96    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
97        label: Some("stem-pack-bind-group"),
98        layout: &bind_group_layout,
99        entries: &[
100            wgpu::BindGroupEntry {
101                binding: 0,
102                resource: inputs.x_buffer.as_entire_binding(),
103            },
104            wgpu::BindGroupEntry {
105                binding: 1,
106                resource: inputs.y_buffer.as_entire_binding(),
107            },
108            wgpu::BindGroupEntry {
109                binding: 2,
110                resource: output_buffer.as_entire_binding(),
111            },
112            wgpu::BindGroupEntry {
113                binding: 3,
114                resource: uniform_buffer.as_entire_binding(),
115            },
116        ],
117    });
118    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
119        label: Some("stem-pack-encoder"),
120    });
121    {
122        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
123            label: Some("stem-pack-pass"),
124            timestamp_writes: None,
125        });
126        pass.set_pipeline(&pipeline);
127        pass.set_bind_group(0, &bind_group, &[]);
128        pass.dispatch_workgroups(inputs.len.div_ceil(workgroup_size), 1, 1);
129    }
130    queue.submit(Some(encoder.finish()));
131    Ok(GpuVertexBuffer::new(output_buffer, vertex_count as usize))
132}
133
134fn compile_shader(
135    device: &Arc<wgpu::Device>,
136    workgroup_size: u32,
137    scalar: ScalarType,
138) -> wgpu::ShaderModule {
139    let template = match scalar {
140        ScalarType::F32 => shaders::stem::F32,
141        ScalarType::F64 => shaders::stem::F64,
142    };
143    let source = template.replace("{{WORKGROUP_SIZE}}", &workgroup_size.to_string());
144    device.create_shader_module(wgpu::ShaderModuleDescriptor {
145        label: Some("stem-pack-shader"),
146        source: wgpu::ShaderSource::Wgsl(source.into()),
147    })
148}
149
150fn storage_entry(binding: u32, read_only: bool) -> wgpu::BindGroupLayoutEntry {
151    wgpu::BindGroupLayoutEntry {
152        binding,
153        visibility: wgpu::ShaderStages::COMPUTE,
154        ty: wgpu::BindingType::Buffer {
155            ty: wgpu::BufferBindingType::Storage { read_only },
156            has_dynamic_offset: false,
157            min_binding_size: None,
158        },
159        count: None,
160    }
161}
162fn uniform_entry(binding: u32) -> wgpu::BindGroupLayoutEntry {
163    wgpu::BindGroupLayoutEntry {
164        binding,
165        visibility: wgpu::ShaderStages::COMPUTE,
166        ty: wgpu::BindingType::Buffer {
167            ty: wgpu::BufferBindingType::Uniform,
168            has_dynamic_offset: false,
169            min_binding_size: None,
170        },
171        count: None,
172    }
173}
174fn line_style_code(style: LineStyle) -> u32 {
175    match style {
176        LineStyle::Solid => 0,
177        LineStyle::Dashed => 1,
178        LineStyle::Dotted => 2,
179        LineStyle::DashDot => 3,
180    }
181}