1use crate::core::renderer::Vertex;
2use crate::core::scene::GpuVertexBuffer;
3use crate::gpu::shaders;
4use crate::gpu::{tuning, ScalarType};
5use crate::plots::line::LineStyle;
6use glam::Vec4;
7use std::sync::Arc;
8use wgpu::util::DeviceExt;
9
10pub struct StemGpuInputs {
11 pub x_buffer: Arc<wgpu::Buffer>,
12 pub y_buffer: Arc<wgpu::Buffer>,
13 pub len: u32,
14 pub scalar: ScalarType,
15}
16
17pub struct StemGpuParams {
18 pub color: Vec4,
19 pub baseline_color: Vec4,
20 pub baseline: f32,
21 pub baseline_visible: bool,
22 pub min_x: f32,
23 pub max_x: f32,
24 pub line_style: LineStyle,
25}
26
/// CPU-side copy of the uniform block bound at binding 3 of the stem packing shader.
///
/// `repr(C)` + `Pod` so the value can be uploaded byte-for-byte with
/// `bytemuck::bytes_of`. Field order and types presumably mirror the WGSL struct in
/// `shaders::stem` — TODO confirm; reordering fields here breaks that correspondence.
///
/// NOTE(review): `size_of::<StemUniforms>()` is 56 bytes (2 x 16 + 6 x 4), but WGSL rounds a
/// uniform struct containing `vec4` members up to a multiple of 16 bytes (64 here). An
/// upload of exactly 56 bytes can fail wgpu's minimum-binding-size check at dispatch time.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct StemUniforms {
    color: [f32; 4],
    baseline_color: [f32; 4],
    baseline: f32,
    min_x: f32,
    max_x: f32,
    // Number of input points to pack.
    point_count: u32,
    // Integer code of a `LineStyle`.
    line_style: u32,
    // 0 or 1; `bool` is not host-shareable in WGSL, so a u32 flag is used instead.
    baseline_visible: u32,
}
39
40pub fn pack_vertices_from_xy(
41 device: &Arc<wgpu::Device>,
42 queue: &Arc<wgpu::Queue>,
43 inputs: &StemGpuInputs,
44 params: &StemGpuParams,
45) -> Result<GpuVertexBuffer, String> {
46 let workgroup_size = tuning::effective_workgroup_size();
47 let shader = compile_shader(device, workgroup_size, inputs.scalar);
48 let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
49 label: Some("stem-pack-bind-layout"),
50 entries: &[
51 storage_entry(0, true),
52 storage_entry(1, true),
53 storage_entry(2, false),
54 uniform_entry(3),
55 ],
56 });
57 let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
58 label: Some("stem-pack-pipeline-layout"),
59 bind_group_layouts: &[&bind_group_layout],
60 push_constant_ranges: &[],
61 });
62 let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
63 label: Some("stem-pack-pipeline"),
64 layout: Some(&pipeline_layout),
65 module: &shader,
66 entry_point: "main",
67 });
68 let baseline_count = if params.baseline_visible { 2 } else { 0 };
69 let vertex_count = baseline_count as u64 + inputs.len as u64 * 2;
70 let output_buffer = Arc::new(device.create_buffer(&wgpu::BufferDescriptor {
71 label: Some("stem-gpu-vertices"),
72 size: vertex_count * std::mem::size_of::<Vertex>() as u64,
73 usage: wgpu::BufferUsages::STORAGE
74 | wgpu::BufferUsages::VERTEX
75 | wgpu::BufferUsages::COPY_DST,
76 mapped_at_creation: false,
77 }));
78 let uniforms = StemUniforms {
79 color: params.color.to_array(),
80 baseline_color: params.baseline_color.to_array(),
81 baseline: params.baseline,
82 min_x: params.min_x,
83 max_x: params.max_x,
84 point_count: inputs.len,
85 line_style: line_style_code(params.line_style),
86 baseline_visible: if params.baseline_visible { 1 } else { 0 },
87 };
88 let uniform_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
89 label: Some("stem-pack-uniforms"),
90 contents: bytemuck::bytes_of(&uniforms),
91 usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
92 });
93 let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
94 label: Some("stem-pack-bind-group"),
95 layout: &bind_group_layout,
96 entries: &[
97 wgpu::BindGroupEntry {
98 binding: 0,
99 resource: inputs.x_buffer.as_entire_binding(),
100 },
101 wgpu::BindGroupEntry {
102 binding: 1,
103 resource: inputs.y_buffer.as_entire_binding(),
104 },
105 wgpu::BindGroupEntry {
106 binding: 2,
107 resource: output_buffer.as_entire_binding(),
108 },
109 wgpu::BindGroupEntry {
110 binding: 3,
111 resource: uniform_buffer.as_entire_binding(),
112 },
113 ],
114 });
115 let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
116 label: Some("stem-pack-encoder"),
117 });
118 {
119 let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
120 label: Some("stem-pack-pass"),
121 timestamp_writes: None,
122 });
123 pass.set_pipeline(&pipeline);
124 pass.set_bind_group(0, &bind_group, &[]);
125 pass.dispatch_workgroups(inputs.len.div_ceil(workgroup_size), 1, 1);
126 }
127 queue.submit(Some(encoder.finish()));
128 Ok(GpuVertexBuffer::new(output_buffer, vertex_count as usize))
129}
130
131fn compile_shader(
132 device: &Arc<wgpu::Device>,
133 workgroup_size: u32,
134 scalar: ScalarType,
135) -> wgpu::ShaderModule {
136 let template = match scalar {
137 ScalarType::F32 => shaders::stem::F32,
138 ScalarType::F64 => shaders::stem::F64,
139 };
140 let source = template.replace("{{WORKGROUP_SIZE}}", &workgroup_size.to_string());
141 device.create_shader_module(wgpu::ShaderModuleDescriptor {
142 label: Some("stem-pack-shader"),
143 source: wgpu::ShaderSource::Wgsl(source.into()),
144 })
145}
146
147fn storage_entry(binding: u32, read_only: bool) -> wgpu::BindGroupLayoutEntry {
148 wgpu::BindGroupLayoutEntry {
149 binding,
150 visibility: wgpu::ShaderStages::COMPUTE,
151 ty: wgpu::BindingType::Buffer {
152 ty: wgpu::BufferBindingType::Storage { read_only },
153 has_dynamic_offset: false,
154 min_binding_size: None,
155 },
156 count: None,
157 }
158}
159fn uniform_entry(binding: u32) -> wgpu::BindGroupLayoutEntry {
160 wgpu::BindGroupLayoutEntry {
161 binding,
162 visibility: wgpu::ShaderStages::COMPUTE,
163 ty: wgpu::BindingType::Buffer {
164 ty: wgpu::BufferBindingType::Uniform,
165 has_dynamic_offset: false,
166 min_binding_size: None,
167 },
168 count: None,
169 }
170}
171fn line_style_code(style: LineStyle) -> u32 {
172 match style {
173 LineStyle::Solid => 0,
174 LineStyle::Dashed => 1,
175 LineStyle::Dotted => 2,
176 LineStyle::DashDot => 3,
177 }
178}