Skip to main content

runmat_plot/gpu/
errorbar.rs

1use crate::core::renderer::Vertex;
2use crate::core::scene::GpuVertexBuffer;
3use crate::gpu::shaders;
4use crate::gpu::{tuning, ScalarType};
5use crate::plots::line::LineStyle;
6use glam::Vec4;
7use std::sync::Arc;
8use wgpu::util::DeviceExt;
9
/// GPU-resident inputs consumed by [`pack_vertical_vertices`].
///
/// Each buffer is bound as a read-only storage buffer for the packing compute
/// shader. The element type of the buffers is selected by `scalar`.
pub struct ErrorBarGpuInputs {
    /// Bound at binding 0. Also bound at bindings 2 and 3 when the matching
    /// optional buffer below is `None`.
    pub x_buffer: Arc<wgpu::Buffer>,
    /// Bound at binding 1.
    pub y_buffer: Arc<wgpu::Buffer>,
    /// Bound at binding 2 when present.
    /// NOTE(review): presumably the negative-direction error in x — confirm against the shader.
    pub x_neg_buffer: Option<Arc<wgpu::Buffer>>,
    /// Bound at binding 3 when present.
    /// NOTE(review): presumably the positive-direction error in x — confirm against the shader.
    pub x_pos_buffer: Option<Arc<wgpu::Buffer>>,
    /// Bound at binding 4.
    /// NOTE(review): presumably the negative-direction error in y — confirm against the shader.
    pub y_neg_buffer: Arc<wgpu::Buffer>,
    /// Bound at binding 5.
    /// NOTE(review): presumably the positive-direction error in y — confirm against the shader.
    pub y_pos_buffer: Arc<wgpu::Buffer>,
    /// Number of points; written to the uniform `count`, used to size the
    /// output vertex buffer and to compute the dispatch size.
    pub len: u32,
    /// Selects which shader template (`F32` or `F64`) is compiled.
    pub scalar: ScalarType,
}
20
/// Styling parameters forwarded to the packing shader through its uniform block.
pub struct ErrorBarGpuParams {
    /// RGBA colour copied into the uniform block as four `f32` components.
    pub color: Vec4,
    /// Cap size; negative values are clamped to zero and the result is halved
    /// before upload (`cap_half_width`).
    /// NOTE(review): the name suggests data-space units — confirm against the shader.
    pub cap_size_data: f32,
    /// Line style, converted to an integer code by `line_style_code`.
    pub line_style: LineStyle,
    /// Orientation selector passed through unchanged to the shader. Values `0`
    /// and `1` reserve 6 output vertices per point; any other value reserves 12.
    /// NOTE(review): the meaning of each value is defined by the shader — confirm there.
    pub orientation: u32,
}
27
/// Uniform block uploaded byte-for-byte (via `bytemuck::bytes_of`) and bound at
/// binding 7 of the packing pipeline.
///
/// The struct is `#[repr(C)]` and consists of eight 4-byte fields (32 bytes).
/// NOTE(review): field order presumably has to mirror the uniform struct
/// declared in the WGSL templates — confirm before reordering or adding fields.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct ErrorBarUniforms {
    // RGBA colour taken from `ErrorBarGpuParams::color`.
    color: [f32; 4],
    // Number of input points (`ErrorBarGpuInputs::len`).
    count: u32,
    // Integer code produced by `line_style_code`.
    line_style: u32,
    // Half of the non-negative cap size.
    cap_half_width: f32,
    // Orientation selector copied from `ErrorBarGpuParams::orientation`.
    orientation: u32,
}
37
38pub fn pack_vertical_vertices(
39    device: &Arc<wgpu::Device>,
40    queue: &Arc<wgpu::Queue>,
41    inputs: &ErrorBarGpuInputs,
42    params: &ErrorBarGpuParams,
43) -> Result<GpuVertexBuffer, String> {
44    let workgroup_size = tuning::effective_workgroup_size();
45    let shader = compile_shader(device, workgroup_size, inputs.scalar);
46    let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
47        label: Some("errorbar-pack-bind-layout"),
48        entries: &[
49            storage_entry(0, true),
50            storage_entry(1, true),
51            storage_entry(2, true),
52            storage_entry(3, true),
53            storage_entry(4, true),
54            storage_entry(5, true),
55            storage_entry(6, false),
56            uniform_entry(7),
57        ],
58    });
59    let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
60        label: Some("errorbar-pack-pipeline-layout"),
61        bind_group_layouts: &[&bind_group_layout],
62        push_constant_ranges: &[],
63    });
64    let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
65        label: Some("errorbar-pack-pipeline"),
66        layout: Some(&pipeline_layout),
67        module: &shader,
68        entry_point: "main",
69    });
70    let segments_per_point = match params.orientation {
71        0 => 6u64,
72        1 => 6u64,
73        _ => 12u64,
74    };
75    let max_vertices = inputs.len as u64 * segments_per_point;
76    let output_buffer = Arc::new(device.create_buffer(&wgpu::BufferDescriptor {
77        label: Some("errorbar-gpu-vertices"),
78        size: max_vertices * std::mem::size_of::<Vertex>() as u64,
79        usage: wgpu::BufferUsages::STORAGE
80            | wgpu::BufferUsages::VERTEX
81            | wgpu::BufferUsages::COPY_DST,
82        mapped_at_creation: false,
83    }));
84    let uniforms = ErrorBarUniforms {
85        color: params.color.to_array(),
86        count: inputs.len,
87        line_style: line_style_code(params.line_style),
88        cap_half_width: params.cap_size_data.max(0.0) * 0.5,
89        orientation: params.orientation,
90    };
91    let uniform_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
92        label: Some("errorbar-pack-uniforms"),
93        contents: bytemuck::bytes_of(&uniforms),
94        usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
95    });
96    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
97        label: Some("errorbar-pack-bind-group"),
98        layout: &bind_group_layout,
99        entries: &[
100            wgpu::BindGroupEntry {
101                binding: 0,
102                resource: inputs.x_buffer.as_entire_binding(),
103            },
104            wgpu::BindGroupEntry {
105                binding: 1,
106                resource: inputs.y_buffer.as_entire_binding(),
107            },
108            wgpu::BindGroupEntry {
109                binding: 2,
110                resource: inputs
111                    .x_neg_buffer
112                    .as_ref()
113                    .map(|b| b.as_entire_binding())
114                    .unwrap_or(inputs.x_buffer.as_entire_binding()),
115            },
116            wgpu::BindGroupEntry {
117                binding: 3,
118                resource: inputs
119                    .x_pos_buffer
120                    .as_ref()
121                    .map(|b| b.as_entire_binding())
122                    .unwrap_or(inputs.x_buffer.as_entire_binding()),
123            },
124            wgpu::BindGroupEntry {
125                binding: 4,
126                resource: inputs.y_neg_buffer.as_entire_binding(),
127            },
128            wgpu::BindGroupEntry {
129                binding: 5,
130                resource: inputs.y_pos_buffer.as_entire_binding(),
131            },
132            wgpu::BindGroupEntry {
133                binding: 6,
134                resource: output_buffer.as_entire_binding(),
135            },
136            wgpu::BindGroupEntry {
137                binding: 7,
138                resource: uniform_buffer.as_entire_binding(),
139            },
140        ],
141    });
142    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
143        label: Some("errorbar-pack-encoder"),
144    });
145    {
146        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
147            label: Some("errorbar-pack-pass"),
148            timestamp_writes: None,
149        });
150        pass.set_pipeline(&pipeline);
151        pass.set_bind_group(0, &bind_group, &[]);
152        pass.dispatch_workgroups(inputs.len.div_ceil(workgroup_size), 1, 1);
153    }
154    queue.submit(Some(encoder.finish()));
155    Ok(GpuVertexBuffer::new(output_buffer, max_vertices as usize))
156}
157
158fn compile_shader(
159    device: &Arc<wgpu::Device>,
160    workgroup_size: u32,
161    scalar: ScalarType,
162) -> wgpu::ShaderModule {
163    let template = match scalar {
164        ScalarType::F32 => shaders::errorbar::F32,
165        ScalarType::F64 => shaders::errorbar::F64,
166    };
167    let source = template.replace("{{WORKGROUP_SIZE}}", &workgroup_size.to_string());
168    device.create_shader_module(wgpu::ShaderModuleDescriptor {
169        label: Some("errorbar-pack-shader"),
170        source: wgpu::ShaderSource::Wgsl(source.into()),
171    })
172}
173
174fn line_style_code(style: LineStyle) -> u32 {
175    match style {
176        LineStyle::Solid => 0,
177        LineStyle::Dashed => 1,
178        LineStyle::Dotted => 2,
179        LineStyle::DashDot => 3,
180    }
181}
182fn storage_entry(binding: u32, read_only: bool) -> wgpu::BindGroupLayoutEntry {
183    wgpu::BindGroupLayoutEntry {
184        binding,
185        visibility: wgpu::ShaderStages::COMPUTE,
186        ty: wgpu::BindingType::Buffer {
187            ty: wgpu::BufferBindingType::Storage { read_only },
188            has_dynamic_offset: false,
189            min_binding_size: None,
190        },
191        count: None,
192    }
193}
194fn uniform_entry(binding: u32) -> wgpu::BindGroupLayoutEntry {
195    wgpu::BindGroupLayoutEntry {
196        binding,
197        visibility: wgpu::ShaderStages::COMPUTE,
198        ty: wgpu::BindingType::Buffer {
199            ty: wgpu::BufferBindingType::Uniform,
200            has_dynamic_offset: false,
201            min_binding_size: None,
202        },
203        count: None,
204    }
205}