Skip to main content

runmat_plot/gpu/
stem.rs

1use crate::core::renderer::Vertex;
2use crate::core::scene::GpuVertexBuffer;
3use crate::gpu::shaders;
4use crate::gpu::{tuning, ScalarType};
5use crate::plots::line::LineStyle;
6use glam::Vec4;
7use std::sync::Arc;
8use wgpu::util::DeviceExt;
9
10pub struct StemGpuInputs {
11    pub x_buffer: Arc<wgpu::Buffer>,
12    pub y_buffer: Arc<wgpu::Buffer>,
13    pub len: u32,
14    pub scalar: ScalarType,
15}
16
17pub struct StemGpuParams {
18    pub color: Vec4,
19    pub baseline_color: Vec4,
20    pub baseline: f32,
21    pub baseline_visible: bool,
22    pub min_x: f32,
23    pub max_x: f32,
24    pub line_style: LineStyle,
25}
26
27#[repr(C)]
28#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
29struct StemUniforms {
30    color: [f32; 4],
31    baseline_color: [f32; 4],
32    baseline: f32,
33    min_x: f32,
34    max_x: f32,
35    point_count: u32,
36    line_style: u32,
37    baseline_visible: u32,
38}
39
40pub fn pack_vertices_from_xy(
41    device: &Arc<wgpu::Device>,
42    queue: &Arc<wgpu::Queue>,
43    inputs: &StemGpuInputs,
44    params: &StemGpuParams,
45) -> Result<GpuVertexBuffer, String> {
46    let workgroup_size = tuning::effective_workgroup_size();
47    let shader = compile_shader(device, workgroup_size, inputs.scalar);
48    let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
49        label: Some("stem-pack-bind-layout"),
50        entries: &[
51            storage_entry(0, true),
52            storage_entry(1, true),
53            storage_entry(2, false),
54            uniform_entry(3),
55        ],
56    });
57    let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
58        label: Some("stem-pack-pipeline-layout"),
59        bind_group_layouts: &[&bind_group_layout],
60        push_constant_ranges: &[],
61    });
62    let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
63        label: Some("stem-pack-pipeline"),
64        layout: Some(&pipeline_layout),
65        module: &shader,
66        entry_point: "main",
67    });
68    let baseline_count = if params.baseline_visible { 2 } else { 0 };
69    let vertex_count = baseline_count as u64 + inputs.len as u64 * 2;
70    let output_buffer = Arc::new(device.create_buffer(&wgpu::BufferDescriptor {
71        label: Some("stem-gpu-vertices"),
72        size: vertex_count * std::mem::size_of::<Vertex>() as u64,
73        usage: wgpu::BufferUsages::STORAGE
74            | wgpu::BufferUsages::VERTEX
75            | wgpu::BufferUsages::COPY_DST,
76        mapped_at_creation: false,
77    }));
78    let uniforms = StemUniforms {
79        color: params.color.to_array(),
80        baseline_color: params.baseline_color.to_array(),
81        baseline: params.baseline,
82        min_x: params.min_x,
83        max_x: params.max_x,
84        point_count: inputs.len,
85        line_style: line_style_code(params.line_style),
86        baseline_visible: if params.baseline_visible { 1 } else { 0 },
87    };
88    let uniform_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
89        label: Some("stem-pack-uniforms"),
90        contents: bytemuck::bytes_of(&uniforms),
91        usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
92    });
93    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
94        label: Some("stem-pack-bind-group"),
95        layout: &bind_group_layout,
96        entries: &[
97            wgpu::BindGroupEntry {
98                binding: 0,
99                resource: inputs.x_buffer.as_entire_binding(),
100            },
101            wgpu::BindGroupEntry {
102                binding: 1,
103                resource: inputs.y_buffer.as_entire_binding(),
104            },
105            wgpu::BindGroupEntry {
106                binding: 2,
107                resource: output_buffer.as_entire_binding(),
108            },
109            wgpu::BindGroupEntry {
110                binding: 3,
111                resource: uniform_buffer.as_entire_binding(),
112            },
113        ],
114    });
115    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
116        label: Some("stem-pack-encoder"),
117    });
118    {
119        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
120            label: Some("stem-pack-pass"),
121            timestamp_writes: None,
122        });
123        pass.set_pipeline(&pipeline);
124        pass.set_bind_group(0, &bind_group, &[]);
125        pass.dispatch_workgroups(inputs.len.div_ceil(workgroup_size), 1, 1);
126    }
127    queue.submit(Some(encoder.finish()));
128    Ok(GpuVertexBuffer::new(output_buffer, vertex_count as usize))
129}
130
131fn compile_shader(
132    device: &Arc<wgpu::Device>,
133    workgroup_size: u32,
134    scalar: ScalarType,
135) -> wgpu::ShaderModule {
136    let template = match scalar {
137        ScalarType::F32 => shaders::stem::F32,
138        ScalarType::F64 => shaders::stem::F64,
139    };
140    let source = template.replace("{{WORKGROUP_SIZE}}", &workgroup_size.to_string());
141    device.create_shader_module(wgpu::ShaderModuleDescriptor {
142        label: Some("stem-pack-shader"),
143        source: wgpu::ShaderSource::Wgsl(source.into()),
144    })
145}
146
147fn storage_entry(binding: u32, read_only: bool) -> wgpu::BindGroupLayoutEntry {
148    wgpu::BindGroupLayoutEntry {
149        binding,
150        visibility: wgpu::ShaderStages::COMPUTE,
151        ty: wgpu::BindingType::Buffer {
152            ty: wgpu::BufferBindingType::Storage { read_only },
153            has_dynamic_offset: false,
154            min_binding_size: None,
155        },
156        count: None,
157    }
158}
159fn uniform_entry(binding: u32) -> wgpu::BindGroupLayoutEntry {
160    wgpu::BindGroupLayoutEntry {
161        binding,
162        visibility: wgpu::ShaderStages::COMPUTE,
163        ty: wgpu::BindingType::Buffer {
164            ty: wgpu::BufferBindingType::Uniform,
165            has_dynamic_offset: false,
166            min_binding_size: None,
167        },
168        count: None,
169    }
170}
171fn line_style_code(style: LineStyle) -> u32 {
172    match style {
173        LineStyle::Solid => 0,
174        LineStyle::Dashed => 1,
175        LineStyle::Dotted => 2,
176        LineStyle::DashDot => 3,
177    }
178}