1use crate::core::renderer::Vertex;
2use crate::core::scene::GpuVertexBuffer;
3use crate::gpu::shaders;
4use crate::gpu::{tuning, ScalarType};
5use crate::plots::line::LineStyle;
6use glam::Vec4;
7use std::sync::Arc;
8use wgpu::util::DeviceExt;
9
/// Device-resident point data consumed by [`pack_vertices_from_xy`].
#[derive(Clone, Debug)]
pub struct StemGpuInputs {
    /// X values; bound read-only at binding 0 of the stem packing shader.
    pub x_buffer: Arc<wgpu::Buffer>,
    /// Y values; bound read-only at binding 1 of the stem packing shader.
    pub y_buffer: Arc<wgpu::Buffer>,
    /// Number of points. Sets the shader's `point_count` uniform, the dispatch
    /// size, and the `2 * len` stem vertices reserved in the output buffer.
    pub len: u32,
    /// Selects the F32 or F64 shader template.
    /// NOTE(review): presumably must match the element type actually stored in
    /// `x_buffer`/`y_buffer` — confirm with the code that fills those buffers.
    pub scalar: ScalarType,
}
17
18pub struct StemGpuParams {
19 pub color: Vec4,
20 pub baseline_color: Vec4,
21 pub baseline: f32,
22 pub baseline_visible: bool,
23 pub min_x: f32,
24 pub max_x: f32,
25 pub line_style: LineStyle,
26}
27
/// CPU-side mirror of the stem packing shader's uniform block.
///
/// `#[repr(C)]` plus `Pod` lets the struct be uploaded verbatim with
/// `bytemuck::bytes_of`, so field order and types must match the WGSL
/// declaration exactly. Do not reorder or resize fields without updating the
/// shader. The layout is 56 bytes with no padding (`Pod` forbids padding).
/// NOTE(review): WGSL rounds a struct containing `vec4` members up to a
/// multiple of 16 bytes (64 here), so a 56-byte upload may be smaller than
/// the shader expects — confirm against the WGSL declaration.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct StemUniforms {
    // Four floats taken from `Vec4::to_array` (stem colour).
    color: [f32; 4],
    // Four floats taken from `Vec4::to_array` (baseline colour).
    baseline_color: [f32; 4],
    baseline: f32,
    min_x: f32,
    max_x: f32,
    // Number of input points.
    point_count: u32,
    // Integer code produced by `line_style_code`.
    line_style: u32,
    // 1 when the baseline is visible, 0 otherwise (bools are not host-shareable in WGSL).
    baseline_visible: u32,
}
40
41pub fn pack_vertices_from_xy(
42 device: &Arc<wgpu::Device>,
43 queue: &Arc<wgpu::Queue>,
44 inputs: &StemGpuInputs,
45 params: &StemGpuParams,
46) -> Result<GpuVertexBuffer, String> {
47 let workgroup_size = tuning::effective_workgroup_size();
48 let shader = compile_shader(device, workgroup_size, inputs.scalar);
49 let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
50 label: Some("stem-pack-bind-layout"),
51 entries: &[
52 storage_entry(0, true),
53 storage_entry(1, true),
54 storage_entry(2, false),
55 uniform_entry(3),
56 ],
57 });
58 let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
59 label: Some("stem-pack-pipeline-layout"),
60 bind_group_layouts: &[&bind_group_layout],
61 push_constant_ranges: &[],
62 });
63 let pipeline =
64 device.create_compute_pipeline(&crate::wgpu_compat::wgpu_compute_pipeline_descriptor! {
65 label: Some("stem-pack-pipeline"),
66 layout: Some(&pipeline_layout),
67 module: &shader,
68 entry_point: "main",
69 });
70 let baseline_count = if params.baseline_visible { 2 } else { 0 };
71 let vertex_count = baseline_count as u64 + inputs.len as u64 * 2;
72 let output_buffer = Arc::new(device.create_buffer(&wgpu::BufferDescriptor {
73 label: Some("stem-gpu-vertices"),
74 size: vertex_count * std::mem::size_of::<Vertex>() as u64,
75 usage: wgpu::BufferUsages::STORAGE
76 | wgpu::BufferUsages::VERTEX
77 | wgpu::BufferUsages::COPY_DST
78 | wgpu::BufferUsages::COPY_SRC,
79 mapped_at_creation: false,
80 }));
81 let uniforms = StemUniforms {
82 color: params.color.to_array(),
83 baseline_color: params.baseline_color.to_array(),
84 baseline: params.baseline,
85 min_x: params.min_x,
86 max_x: params.max_x,
87 point_count: inputs.len,
88 line_style: line_style_code(params.line_style),
89 baseline_visible: if params.baseline_visible { 1 } else { 0 },
90 };
91 let uniform_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
92 label: Some("stem-pack-uniforms"),
93 contents: bytemuck::bytes_of(&uniforms),
94 usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
95 });
96 let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
97 label: Some("stem-pack-bind-group"),
98 layout: &bind_group_layout,
99 entries: &[
100 wgpu::BindGroupEntry {
101 binding: 0,
102 resource: inputs.x_buffer.as_entire_binding(),
103 },
104 wgpu::BindGroupEntry {
105 binding: 1,
106 resource: inputs.y_buffer.as_entire_binding(),
107 },
108 wgpu::BindGroupEntry {
109 binding: 2,
110 resource: output_buffer.as_entire_binding(),
111 },
112 wgpu::BindGroupEntry {
113 binding: 3,
114 resource: uniform_buffer.as_entire_binding(),
115 },
116 ],
117 });
118 let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
119 label: Some("stem-pack-encoder"),
120 });
121 {
122 let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
123 label: Some("stem-pack-pass"),
124 timestamp_writes: None,
125 });
126 pass.set_pipeline(&pipeline);
127 pass.set_bind_group(0, &bind_group, &[]);
128 pass.dispatch_workgroups(inputs.len.div_ceil(workgroup_size), 1, 1);
129 }
130 queue.submit(Some(encoder.finish()));
131 Ok(GpuVertexBuffer::new(output_buffer, vertex_count as usize))
132}
133
134fn compile_shader(
135 device: &Arc<wgpu::Device>,
136 workgroup_size: u32,
137 scalar: ScalarType,
138) -> wgpu::ShaderModule {
139 let template = match scalar {
140 ScalarType::F32 => shaders::stem::F32,
141 ScalarType::F64 => shaders::stem::F64,
142 };
143 let source = template.replace("{{WORKGROUP_SIZE}}", &workgroup_size.to_string());
144 device.create_shader_module(wgpu::ShaderModuleDescriptor {
145 label: Some("stem-pack-shader"),
146 source: wgpu::ShaderSource::Wgsl(source.into()),
147 })
148}
149
150fn storage_entry(binding: u32, read_only: bool) -> wgpu::BindGroupLayoutEntry {
151 wgpu::BindGroupLayoutEntry {
152 binding,
153 visibility: wgpu::ShaderStages::COMPUTE,
154 ty: wgpu::BindingType::Buffer {
155 ty: wgpu::BufferBindingType::Storage { read_only },
156 has_dynamic_offset: false,
157 min_binding_size: None,
158 },
159 count: None,
160 }
161}
162fn uniform_entry(binding: u32) -> wgpu::BindGroupLayoutEntry {
163 wgpu::BindGroupLayoutEntry {
164 binding,
165 visibility: wgpu::ShaderStages::COMPUTE,
166 ty: wgpu::BindingType::Buffer {
167 ty: wgpu::BufferBindingType::Uniform,
168 has_dynamic_offset: false,
169 min_binding_size: None,
170 },
171 count: None,
172 }
173}
174fn line_style_code(style: LineStyle) -> u32 {
175 match style {
176 LineStyle::Solid => 0,
177 LineStyle::Dashed => 1,
178 LineStyle::Dotted => 2,
179 LineStyle::DashDot => 3,
180 }
181}