Skip to main content

engvis_renderer/
grid_renderer.rs

1use wgpu::util::DeviceExt;
2
3/// Vertex for grid/axis lines (position + color)
4#[repr(C)]
5#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
6struct GridVertex {
7    position: [f32; 3],
8    color: [f32; 4],
9}
10
11pub struct GridRenderer {
12    pub vertex_buffer: wgpu::Buffer,
13    pub vertex_count: u32,
14    pub pipeline: wgpu::RenderPipeline,
15    pub bind_group_layout: wgpu::BindGroupLayout,
16    pub bind_group: wgpu::BindGroup,
17    pub uniform_buffer: wgpu::Buffer,
18}
19
20impl GridRenderer {
21    pub fn new(
22        device: &wgpu::Device,
23        surface_format: wgpu::TextureFormat,
24        scene_layout: &wgpu::BindGroupLayout,
25    ) -> Self {
26        let (vertices, vertex_count) = Self::generate_grid_vertices();
27
28        let vertex_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
29            label: Some("Grid Vertex Buffer"),
30            contents: bytemuck::cast_slice(&vertices),
31            usage: wgpu::BufferUsages::VERTEX,
32        });
33
34        // Simple bind group for grid: just scene uniforms (group 0)
35        // and an identity model matrix uniform (group 1)
36        let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
37            label: Some("Grid Object Bind Group Layout"),
38            entries: &[wgpu::BindGroupLayoutEntry {
39                binding: 0,
40                visibility: wgpu::ShaderStages::VERTEX,
41                ty: wgpu::BindingType::Buffer {
42                    ty: wgpu::BufferBindingType::Uniform,
43                    has_dynamic_offset: false,
44                    min_binding_size: None,
45                },
46                count: None,
47            }],
48        });
49
50        // Identity model matrix
51        let identity_matrix: [[f32; 4]; 4] = [
52            [1.0, 0.0, 0.0, 0.0],
53            [0.0, 1.0, 0.0, 0.0],
54            [0.0, 0.0, 1.0, 0.0],
55            [0.0, 0.0, 0.0, 1.0],
56        ];
57        // Pack as { model: mat4, normal_matrix: mat4 } same as ObjectUniforms
58        let uniform_data: [[f32; 4]; 8] = [
59            identity_matrix[0],
60            identity_matrix[1],
61            identity_matrix[2],
62            identity_matrix[3],
63            identity_matrix[0],
64            identity_matrix[1],
65            identity_matrix[2],
66            identity_matrix[3],
67        ];
68
69        let uniform_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
70            label: Some("Grid Object Uniform Buffer"),
71            contents: bytemuck::cast_slice(&uniform_data),
72            usage: wgpu::BufferUsages::UNIFORM,
73        });
74
75        let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
76            label: Some("Grid Object Bind Group"),
77            layout: &bind_group_layout,
78            entries: &[wgpu::BindGroupEntry {
79                binding: 0,
80                resource: uniform_buffer.as_entire_binding(),
81            }],
82        });
83
84        let shader_source = Self::build_shader_source();
85        let shader_module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
86            label: Some("Grid Shader"),
87            source: wgpu::ShaderSource::Wgsl(shader_source.into()),
88        });
89
90        let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
91            label: Some("Grid Pipeline Layout"),
92            bind_group_layouts: &[scene_layout, &bind_group_layout],
93            push_constant_ranges: &[],
94        });
95
96        let pipeline = device.create_render_pipeline(&wgpu::RenderPipelineDescriptor {
97            label: Some("Grid Pipeline"),
98            layout: Some(&pipeline_layout),
99            vertex: wgpu::VertexState {
100                module: &shader_module,
101                entry_point: Some("vs_main"),
102                buffers: &[wgpu::VertexBufferLayout {
103                    array_stride: std::mem::size_of::<GridVertex>() as wgpu::BufferAddress,
104                    step_mode: wgpu::VertexStepMode::Vertex,
105                    attributes: &[
106                        wgpu::VertexAttribute {
107                            format: wgpu::VertexFormat::Float32x3,
108                            offset: 0,
109                            shader_location: 0,
110                        },
111                        wgpu::VertexAttribute {
112                            format: wgpu::VertexFormat::Float32x4,
113                            offset: 12,
114                            shader_location: 1,
115                        },
116                    ],
117                }],
118                compilation_options: wgpu::PipelineCompilationOptions::default(),
119            },
120            fragment: Some(wgpu::FragmentState {
121                module: &shader_module,
122                entry_point: Some("fs_main"),
123                targets: &[Some(wgpu::ColorTargetState {
124                    format: surface_format,
125                    blend: Some(wgpu::BlendState::ALPHA_BLENDING),
126                    write_mask: wgpu::ColorWrites::ALL,
127                })],
128                compilation_options: wgpu::PipelineCompilationOptions::default(),
129            }),
130            primitive: wgpu::PrimitiveState {
131                topology: wgpu::PrimitiveTopology::LineList,
132                strip_index_format: None,
133                front_face: wgpu::FrontFace::Ccw,
134                cull_mode: None,
135                polygon_mode: wgpu::PolygonMode::Fill,
136                unclipped_depth: false,
137                conservative: false,
138            },
139            depth_stencil: Some(wgpu::DepthStencilState {
140                format: crate::depth::DepthTexture::FORMAT,
141                depth_write_enabled: false,
142                depth_compare: wgpu::CompareFunction::LessEqual,
143                stencil: wgpu::StencilState::default(),
144                bias: wgpu::DepthBiasState::default(),
145            }),
146            multisample: wgpu::MultisampleState {
147                count: 4,
148                mask: !0,
149                alpha_to_coverage_enabled: false,
150            },
151            multiview: None,
152            cache: None,
153        });
154
155        Self {
156            vertex_buffer,
157            vertex_count,
158            pipeline,
159            bind_group_layout,
160            bind_group,
161            uniform_buffer,
162        }
163    }
164
165    pub fn render<'a>(&'a self, render_pass: &mut wgpu::RenderPass<'a>) {
166        render_pass.set_pipeline(&self.pipeline);
167        render_pass.set_bind_group(1, &self.bind_group, &[]);
168        render_pass.set_vertex_buffer(0, self.vertex_buffer.slice(..));
169        render_pass.draw(0..self.vertex_count, 0..1);
170    }
171
172    fn generate_grid_vertices() -> (Vec<GridVertex>, u32) {
173        let mut vertices = Vec::new();
174
175        let grid_half = 25;
176        let major_every = 5;
177
178        // Grid lines on XZ plane
179        for i in -grid_half..=grid_half {
180            let is_major = i % major_every == 0;
181            let alpha = if is_major { 0.4 } else { 0.15 };
182            let color = [0.5, 0.5, 0.5, alpha];
183
184            let fi = i as f32;
185            // Line along Z at x = fi
186            vertices.push(GridVertex {
187                position: [fi, 0.0, -grid_half as f32],
188                color,
189            });
190            vertices.push(GridVertex {
191                position: [fi, 0.0, grid_half as f32],
192                color,
193            });
194            // Line along X at z = fi
195            vertices.push(GridVertex {
196                position: [-grid_half as f32, 0.0, fi],
197                color,
198            });
199            vertices.push(GridVertex {
200                position: [grid_half as f32, 0.0, fi],
201                color,
202            });
203        }
204
205        // X axis (red)
206        vertices.push(GridVertex {
207            position: [0.0, 0.0, 0.0],
208            color: [1.0, 0.2, 0.2, 0.9],
209        });
210        vertices.push(GridVertex {
211            position: [grid_half as f32, 0.0, 0.0],
212            color: [1.0, 0.2, 0.2, 0.9],
213        });
214
215        // Y axis (green)
216        vertices.push(GridVertex {
217            position: [0.0, 0.0, 0.0],
218            color: [0.2, 1.0, 0.2, 0.9],
219        });
220        vertices.push(GridVertex {
221            position: [0.0, grid_half as f32, 0.0],
222            color: [0.2, 1.0, 0.2, 0.9],
223        });
224
225        // Z axis (blue)
226        vertices.push(GridVertex {
227            position: [0.0, 0.0, 0.0],
228            color: [0.3, 0.4, 1.0, 0.9],
229        });
230        vertices.push(GridVertex {
231            position: [0.0, 0.0, grid_half as f32],
232            color: [0.3, 0.4, 1.0, 0.9],
233        });
234
235        let count = vertices.len() as u32;
236        (vertices, count)
237    }
238
239    fn build_shader_source() -> String {
240        r#"
241struct SceneUniforms {
242    view_proj: mat4x4<f32>,
243    camera_pos: vec4<f32>,
244    viewport: vec4<f32>,
245    global_opacity: vec4<f32>,
246}
247
248struct ObjectUniforms {
249    model: mat4x4<f32>,
250    normal_matrix: mat4x4<f32>,
251}
252
253@group(0) @binding(0) var<uniform> scene: SceneUniforms;
254@group(1) @binding(0) var<uniform> object: ObjectUniforms;
255
256struct VertexInput {
257    @location(0) position: vec3<f32>,
258    @location(1) color: vec4<f32>,
259}
260
261struct VertexOutput {
262    @builtin(position) clip_pos: vec4<f32>,
263    @location(0) color: vec4<f32>,
264}
265
266@vertex
267fn vs_main(in: VertexInput) -> VertexOutput {
268    var out: VertexOutput;
269    let world_pos = (object.model * vec4<f32>(in.position, 1.0)).xyz;
270    out.clip_pos = scene.view_proj * vec4<f32>(world_pos, 1.0);
271    out.color = in.color;
272    return out;
273}
274
275@fragment
276fn fs_main(in: VertexOutput) -> @location(0) vec4<f32> {
277    return in.color;
278}
279"#
280        .to_string()
281    }
282}