Skip to main content

runmat_plot/gpu/
stem.rs

1use crate::core::renderer::Vertex;
2use crate::core::scene::GpuVertexBuffer;
3use crate::gpu::shaders;
4use crate::gpu::{tuning, ScalarType};
5use crate::plots::line::LineStyle;
6use glam::Vec4;
7use std::sync::Arc;
8use wgpu::util::DeviceExt;
9
10pub struct StemGpuInputs {
11    pub x_buffer: Arc<wgpu::Buffer>,
12    pub y_buffer: Arc<wgpu::Buffer>,
13    pub len: u32,
14    pub scalar: ScalarType,
15}
16
17pub struct StemGpuParams {
18    pub color: Vec4,
19    pub baseline_color: Vec4,
20    pub baseline: f32,
21    pub baseline_visible: bool,
22    pub min_x: f32,
23    pub max_x: f32,
24    pub line_style: LineStyle,
25}
26
27#[repr(C)]
28#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
29struct StemUniforms {
30    color: [f32; 4],
31    baseline_color: [f32; 4],
32    baseline: f32,
33    min_x: f32,
34    max_x: f32,
35    point_count: u32,
36    line_style: u32,
37    baseline_visible: u32,
38}
39
40pub fn pack_vertices_from_xy(
41    device: &Arc<wgpu::Device>,
42    queue: &Arc<wgpu::Queue>,
43    inputs: &StemGpuInputs,
44    params: &StemGpuParams,
45) -> Result<GpuVertexBuffer, String> {
46    let workgroup_size = tuning::effective_workgroup_size();
47    let shader = compile_shader(device, workgroup_size, inputs.scalar);
48    let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
49        label: Some("stem-pack-bind-layout"),
50        entries: &[
51            storage_entry(0, true),
52            storage_entry(1, true),
53            storage_entry(2, false),
54            uniform_entry(3),
55        ],
56    });
57    let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
58        label: Some("stem-pack-pipeline-layout"),
59        bind_group_layouts: &[&bind_group_layout],
60        push_constant_ranges: &[],
61    });
62    let pipeline =
63        device.create_compute_pipeline(&crate::wgpu_compat::wgpu_compute_pipeline_descriptor! {
64            label: Some("stem-pack-pipeline"),
65            layout: Some(&pipeline_layout),
66            module: &shader,
67            entry_point: "main",
68        });
69    let baseline_count = if params.baseline_visible { 2 } else { 0 };
70    let vertex_count = baseline_count as u64 + inputs.len as u64 * 2;
71    let output_buffer = Arc::new(device.create_buffer(&wgpu::BufferDescriptor {
72        label: Some("stem-gpu-vertices"),
73        size: vertex_count * std::mem::size_of::<Vertex>() as u64,
74        usage: wgpu::BufferUsages::STORAGE
75            | wgpu::BufferUsages::VERTEX
76            | wgpu::BufferUsages::COPY_DST,
77        mapped_at_creation: false,
78    }));
79    let uniforms = StemUniforms {
80        color: params.color.to_array(),
81        baseline_color: params.baseline_color.to_array(),
82        baseline: params.baseline,
83        min_x: params.min_x,
84        max_x: params.max_x,
85        point_count: inputs.len,
86        line_style: line_style_code(params.line_style),
87        baseline_visible: if params.baseline_visible { 1 } else { 0 },
88    };
89    let uniform_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
90        label: Some("stem-pack-uniforms"),
91        contents: bytemuck::bytes_of(&uniforms),
92        usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
93    });
94    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
95        label: Some("stem-pack-bind-group"),
96        layout: &bind_group_layout,
97        entries: &[
98            wgpu::BindGroupEntry {
99                binding: 0,
100                resource: inputs.x_buffer.as_entire_binding(),
101            },
102            wgpu::BindGroupEntry {
103                binding: 1,
104                resource: inputs.y_buffer.as_entire_binding(),
105            },
106            wgpu::BindGroupEntry {
107                binding: 2,
108                resource: output_buffer.as_entire_binding(),
109            },
110            wgpu::BindGroupEntry {
111                binding: 3,
112                resource: uniform_buffer.as_entire_binding(),
113            },
114        ],
115    });
116    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
117        label: Some("stem-pack-encoder"),
118    });
119    {
120        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
121            label: Some("stem-pack-pass"),
122            timestamp_writes: None,
123        });
124        pass.set_pipeline(&pipeline);
125        pass.set_bind_group(0, &bind_group, &[]);
126        pass.dispatch_workgroups(inputs.len.div_ceil(workgroup_size), 1, 1);
127    }
128    queue.submit(Some(encoder.finish()));
129    Ok(GpuVertexBuffer::new(output_buffer, vertex_count as usize))
130}
131
132fn compile_shader(
133    device: &Arc<wgpu::Device>,
134    workgroup_size: u32,
135    scalar: ScalarType,
136) -> wgpu::ShaderModule {
137    let template = match scalar {
138        ScalarType::F32 => shaders::stem::F32,
139        ScalarType::F64 => shaders::stem::F64,
140    };
141    let source = template.replace("{{WORKGROUP_SIZE}}", &workgroup_size.to_string());
142    device.create_shader_module(wgpu::ShaderModuleDescriptor {
143        label: Some("stem-pack-shader"),
144        source: wgpu::ShaderSource::Wgsl(source.into()),
145    })
146}
147
148fn storage_entry(binding: u32, read_only: bool) -> wgpu::BindGroupLayoutEntry {
149    wgpu::BindGroupLayoutEntry {
150        binding,
151        visibility: wgpu::ShaderStages::COMPUTE,
152        ty: wgpu::BindingType::Buffer {
153            ty: wgpu::BufferBindingType::Storage { read_only },
154            has_dynamic_offset: false,
155            min_binding_size: None,
156        },
157        count: None,
158    }
159}
160fn uniform_entry(binding: u32) -> wgpu::BindGroupLayoutEntry {
161    wgpu::BindGroupLayoutEntry {
162        binding,
163        visibility: wgpu::ShaderStages::COMPUTE,
164        ty: wgpu::BindingType::Buffer {
165            ty: wgpu::BufferBindingType::Uniform,
166            has_dynamic_offset: false,
167            min_binding_size: None,
168        },
169        count: None,
170    }
171}
172fn line_style_code(style: LineStyle) -> u32 {
173    match style {
174        LineStyle::Solid => 0,
175        LineStyle::Dashed => 1,
176        LineStyle::Dotted => 2,
177        LineStyle::DashDot => 3,
178    }
179}