use glam::DVec2;
use iced::wgpu::{util::DeviceExt, *};
use crate::plot_state::PlotState;
/// GPU resources for drawing the plot's background grid lines.
///
/// The render pipeline is created lazily by [`Grid::ensure_pipeline`], the
/// vertex buffer is (re)built by [`Grid::update`], and both are consumed by
/// [`Grid::draw`].
#[derive(Default)]
pub(crate) struct Grid {
    pipeline: Option<RenderPipeline>,
    /// `None` until the first non-empty upload, and whenever the grid has no lines.
    vertex_buffer: Option<Buffer>,
    /// Number of vertices stored in `vertex_buffer` (two per grid line).
    vertex_count: u32,
    /// CPU copy of the vertex data most recently uploaded. It acts as the
    /// cache key: the GPU buffer is only recreated when freshly generated
    /// vertex data differs from it.
    last_vertices: Vec<f32>,
}

/// Visual weight of an axis tick, which determines how strongly its grid
/// line is drawn.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TickWeight {
    /// Drawn by the grid with the highest alpha.
    Major,
    /// Drawn by the grid with an intermediate alpha.
    Minor,
    /// Drawn by the grid with the lowest alpha.
    SubMinor,
}

impl Grid {
    /// Number of `f32`s per vertex: `[x, y, alpha]`.
    const FLOATS_PER_VERTEX: usize = 3;

    /// Creates the grid render pipeline if it has not been created yet.
    ///
    /// The pipeline uses the `vs_main`/`fs_main` entry points of
    /// `shaders/grid.wgsl`, draws a line list with alpha blending into a
    /// single `format` color target, and expects `camera_bgl` at bind group 0.
    ///
    /// NOTE: once built, the pipeline is kept even if a later call passes a
    /// different `format`.
    pub(crate) fn ensure_pipeline(
        &mut self,
        device: &Device,
        format: TextureFormat,
        camera_bgl: &BindGroupLayout,
    ) {
        if self.pipeline.is_some() {
            return;
        }
        let shader = device.create_shader_module(include_wgsl!("shaders/grid.wgsl"));
        let layout = device.create_pipeline_layout(&PipelineLayoutDescriptor {
            label: Some("Grid Pipeline Layout"),
            bind_group_layouts: &[camera_bgl],
            push_constant_ranges: &[],
        });
        let pipeline = device.create_render_pipeline(&RenderPipelineDescriptor {
            label: Some("Grid Pipeline"),
            layout: Some(&layout),
            vertex: VertexState {
                module: &shader,
                entry_point: Some("vs_main"),
                compilation_options: PipelineCompilationOptions::default(),
                buffers: &[VertexBufferLayout {
                    // Interleaved `[x, y, alpha]` per vertex.
                    array_stride: (Self::FLOATS_PER_VERTEX * std::mem::size_of::<f32>()) as u64,
                    step_mode: VertexStepMode::Vertex,
                    attributes: &[
                        // location 0: position (x, y)
                        VertexAttribute {
                            offset: 0,
                            shader_location: 0,
                            format: VertexFormat::Float32x2,
                        },
                        // location 1: line alpha
                        VertexAttribute {
                            offset: std::mem::size_of::<[f32; 2]>() as u64,
                            shader_location: 1,
                            format: VertexFormat::Float32,
                        },
                    ],
                }],
            },
            fragment: Some(FragmentState {
                module: &shader,
                entry_point: Some("fs_main"),
                compilation_options: PipelineCompilationOptions::default(),
                targets: &[Some(ColorTargetState {
                    format,
                    blend: Some(BlendState::ALPHA_BLENDING),
                    write_mask: ColorWrites::ALL,
                })],
            }),
            primitive: PrimitiveState {
                topology: PrimitiveTopology::LineList,
                strip_index_format: None,
                front_face: FrontFace::Ccw,
                cull_mode: None,
                polygon_mode: PolygonMode::Fill,
                unclipped_depth: false,
                conservative: false,
            },
            depth_stencil: None,
            multisample: MultisampleState {
                count: 1,
                mask: !0,
                alpha_to_coverage_enabled: false,
            },
            multiview: None,
            cache: None,
        });
        self.pipeline = Some(pipeline);
    }

    /// Rebuilds the grid line geometry from `state` and uploads it to the GPU
    /// when it differs from the data uploaded last time.
    ///
    /// One vertical line is emitted per entry of `state.x_ticks` and one
    /// horizontal line per entry of `state.y_ticks`. Each line spans the
    /// camera's visible extent around `camera.effective_position()`. Tick
    /// screen positions are mapped back into that space using `state.bounds`.
    ///
    /// The generated vertex data is itself the cache key. The buffer is
    /// therefore refreshed whenever anything that affects the lines changes
    /// (camera, tick set, or viewport bounds), and left alone otherwise.
    pub(crate) fn update(&mut self, device: &Device, state: &PlotState) {
        let camera = &state.camera;
        let render_center = camera.effective_position();
        let min_x = render_center.x - camera.half_extents.x;
        let max_x = render_center.x + camera.half_extents.x;
        let min_y = render_center.y - camera.half_extents.y;
        let max_y = render_center.y + camera.half_extents.y;

        // Clamp so a zero-sized viewport cannot produce infinite/NaN positions.
        let width = state.bounds.width.max(1.0);
        let height = state.bounds.height.max(1.0);

        let mut verts = Vec::new();

        for positioned_tick in &state.x_ticks {
            // Screen x and NDC x both grow to the right.
            let ndc_x = (positioned_tick.screen_pos / width) as f64 * 2.0 - 1.0;
            let render_x = render_center.x + ndc_x * camera.half_extents.x;
            Self::push_line(
                &mut verts,
                DVec2::new(render_x, min_y),
                DVec2::new(render_x, max_y),
                Self::tick_alpha(positioned_tick.tick.line_type),
            );
        }

        for positioned_tick in &state.y_ticks {
            // Screen y grows downwards while NDC y grows upwards, hence the flip.
            let ndc_y = 1.0 - (positioned_tick.screen_pos / height) as f64 * 2.0;
            let render_y = render_center.y + ndc_y * camera.half_extents.y;
            Self::push_line(
                &mut verts,
                DVec2::new(min_x, render_y),
                DVec2::new(max_x, render_y),
                Self::tick_alpha(positioned_tick.tick.line_type),
            );
        }

        if verts == self.last_vertices {
            // Identical geometry: keep the existing GPU buffer.
            return;
        }

        self.vertex_count = u32::try_from(verts.len() / Self::FLOATS_PER_VERTEX)
            .expect("grid vertex count exceeds u32::MAX");
        // No lines means nothing to draw; avoid allocating an empty GPU buffer.
        self.vertex_buffer = (!verts.is_empty()).then(|| {
            device.create_buffer_init(&util::BufferInitDescriptor {
                label: Some("Grid VB"),
                contents: bytemuck::cast_slice(&verts),
                usage: BufferUsages::VERTEX,
            })
        });
        self.last_vertices = verts;
    }

    /// Records the grid draw into `pass` using `camera_bind_group` at group 0.
    ///
    /// Does nothing if there are no grid vertices, or if the pipeline or
    /// vertex buffer has not been created yet.
    pub(crate) fn draw<'a>(&'a self, pass: &mut RenderPass<'a>, camera_bind_group: &'a BindGroup) {
        if self.vertex_count == 0 {
            return;
        }
        let (Some(pipeline), Some(vertex_buffer)) = (&self.pipeline, &self.vertex_buffer) else {
            return;
        };
        pass.set_pipeline(pipeline);
        pass.set_bind_group(0, camera_bind_group, &[]);
        pass.set_vertex_buffer(0, vertex_buffer.slice(..));
        pass.draw(0..self.vertex_count, 0..1);
    }

    /// Line opacity used for a tick of the given weight.
    fn tick_alpha(weight: TickWeight) -> f32 {
        const GRID_MAJOR_ALPHA: f32 = 0.45;
        const GRID_MINOR_ALPHA: f32 = 0.28;
        const GRID_SUB_MINOR_ALPHA: f32 = 0.10;
        match weight {
            TickWeight::Major => GRID_MAJOR_ALPHA,
            TickWeight::Minor => GRID_MINOR_ALPHA,
            TickWeight::SubMinor => GRID_SUB_MINOR_ALPHA,
        }
    }

    /// Appends one line segment (two `[x, y, alpha]` vertices) to `verts`,
    /// narrowing the coordinates to `f32` for upload.
    fn push_line(verts: &mut Vec<f32>, from: DVec2, to: DVec2, alpha: f32) {
        verts.extend_from_slice(&[
            from.x as f32,
            from.y as f32,
            alpha,
            to.x as f32,
            to.y as f32,
            alpha,
        ]);
    }
}