1use bytemuck::{Pod, Zeroable};
43use wgpu::util::DeviceExt;
44
45use super::gpu::GpuContext;
46
47#[repr(C)]
49#[derive(Copy, Clone, Pod, Zeroable)]
50pub struct GeoVertex {
51 pub position: [f32; 2],
52 pub color: [f32; 4],
53}
54
/// Maximum number of vertices a `GeometryBatch` queues between flushes.
///
/// At 24 bytes per `GeoVertex` this bounds a single upload to 1.5 MiB;
/// further geometry is dropped until the batch is flushed or cleared.
const MAX_VERTICES: usize = 1 << 16;
58
59pub struct GeometryBatch {
60 pipeline: wgpu::RenderPipeline,
61 vertices: Vec<GeoVertex>,
62}
63
64impl GeometryBatch {
65 pub fn new(gpu: &GpuContext) -> Self {
70 let shader = gpu.device.create_shader_module(wgpu::ShaderModuleDescriptor {
71 label: Some("geom_shader"),
72 source: wgpu::ShaderSource::Wgsl(include_str!("shaders/geom.wgsl").into()),
73 });
74
75 let camera_bgl =
76 gpu.device
77 .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
78 label: Some("geom_camera_bind_group_layout"),
79 entries: &[wgpu::BindGroupLayoutEntry {
80 binding: 0,
81 visibility: wgpu::ShaderStages::VERTEX,
82 ty: wgpu::BindingType::Buffer {
83 ty: wgpu::BufferBindingType::Uniform,
84 has_dynamic_offset: false,
85 min_binding_size: None,
86 },
87 count: None,
88 }],
89 });
90
91 let pipeline_layout =
92 gpu.device
93 .create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
94 label: Some("geom_pipeline_layout"),
95 bind_group_layouts: &[&camera_bgl],
96 push_constant_ranges: &[],
97 });
98
99 let vertex_layout = wgpu::VertexBufferLayout {
100 array_stride: std::mem::size_of::<GeoVertex>() as wgpu::BufferAddress,
101 step_mode: wgpu::VertexStepMode::Vertex,
102 attributes: &[
103 wgpu::VertexAttribute {
105 offset: 0,
106 shader_location: 0,
107 format: wgpu::VertexFormat::Float32x2,
108 },
109 wgpu::VertexAttribute {
111 offset: 8,
112 shader_location: 1,
113 format: wgpu::VertexFormat::Float32x4,
114 },
115 ],
116 };
117
118 let pipeline = gpu.device.create_render_pipeline(&wgpu::RenderPipelineDescriptor {
119 label: Some("geom_pipeline"),
120 layout: Some(&pipeline_layout),
121 vertex: wgpu::VertexState {
122 module: &shader,
123 entry_point: Some("vs_main"),
124 buffers: &[vertex_layout],
125 compilation_options: Default::default(),
126 },
127 fragment: Some(wgpu::FragmentState {
128 module: &shader,
129 entry_point: Some("fs_main"),
130 targets: &[Some(wgpu::ColorTargetState {
131 format: gpu.config.format,
132 blend: Some(wgpu::BlendState::ALPHA_BLENDING),
133 write_mask: wgpu::ColorWrites::ALL,
134 })],
135 compilation_options: Default::default(),
136 }),
137 primitive: wgpu::PrimitiveState {
138 topology: wgpu::PrimitiveTopology::TriangleList,
139 strip_index_format: None,
140 front_face: wgpu::FrontFace::Ccw,
141 cull_mode: None,
142 polygon_mode: wgpu::PolygonMode::Fill,
143 unclipped_depth: false,
144 conservative: false,
145 },
146 depth_stencil: None,
147 multisample: wgpu::MultisampleState::default(),
148 multiview: None,
149 cache: None,
150 });
151
152 Self {
153 pipeline,
154 vertices: Vec::with_capacity(MAX_VERTICES),
155 }
156 }
157
158 pub fn add_triangle(
160 &mut self,
161 x1: f32, y1: f32,
162 x2: f32, y2: f32,
163 x3: f32, y3: f32,
164 r: f32, g: f32, b: f32, a: f32,
165 ) {
166 if self.vertices.len() + 3 > MAX_VERTICES {
167 return; }
169 let color = [r, g, b, a];
170 self.vertices.push(GeoVertex { position: [x1, y1], color });
171 self.vertices.push(GeoVertex { position: [x2, y2], color });
172 self.vertices.push(GeoVertex { position: [x3, y3], color });
173 }
174
175 pub fn add_line(
178 &mut self,
179 x1: f32, y1: f32,
180 x2: f32, y2: f32,
181 thickness: f32,
182 r: f32, g: f32, b: f32, a: f32,
183 ) {
184 if self.vertices.len() + 6 > MAX_VERTICES {
185 return;
186 }
187 let dx = x2 - x1;
188 let dy = y2 - y1;
189 let len = (dx * dx + dy * dy).sqrt();
190 if len < 1e-8 {
191 return; }
193 let half = thickness * 0.5;
195 let nx = -dy / len * half;
196 let ny = dx / len * half;
197
198 let color = [r, g, b, a];
199 let a0 = GeoVertex { position: [x1 + nx, y1 + ny], color };
201 let b0 = GeoVertex { position: [x1 - nx, y1 - ny], color };
202 let c0 = GeoVertex { position: [x2 - nx, y2 - ny], color };
203 let d0 = GeoVertex { position: [x2 + nx, y2 + ny], color };
204
205 self.vertices.push(a0);
207 self.vertices.push(b0);
208 self.vertices.push(c0);
209 self.vertices.push(a0);
210 self.vertices.push(c0);
211 self.vertices.push(d0);
212 }
213
214 pub fn flush(
217 &mut self,
218 gpu: &GpuContext,
219 encoder: &mut wgpu::CommandEncoder,
220 target: &wgpu::TextureView,
221 camera_bind_group: &wgpu::BindGroup,
222 ) {
223 if self.vertices.is_empty() {
224 return;
225 }
226
227 let vertex_buffer = gpu.device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
228 label: Some("geom_vertex_buffer"),
229 contents: bytemuck::cast_slice(&self.vertices),
230 usage: wgpu::BufferUsages::VERTEX,
231 });
232
233 let vertex_count = self.vertices.len() as u32;
234
235 {
236 let mut pass = encoder.begin_render_pass(&wgpu::RenderPassDescriptor {
237 label: Some("geom_render_pass"),
238 color_attachments: &[Some(wgpu::RenderPassColorAttachment {
239 view: target,
240 resolve_target: None,
241 ops: wgpu::Operations {
242 load: wgpu::LoadOp::Load, store: wgpu::StoreOp::Store,
244 },
245 })],
246 depth_stencil_attachment: None,
247 timestamp_writes: None,
248 occlusion_query_set: None,
249 });
250
251 pass.set_pipeline(&self.pipeline);
252 pass.set_bind_group(0, camera_bind_group, &[]);
253 pass.set_vertex_buffer(0, vertex_buffer.slice(..));
254 pass.draw(0..vertex_count, 0..1);
255 }
256
257 self.vertices.clear();
258 }
259
260 pub fn clear(&mut self) {
262 self.vertices.clear();
263 }
264}
265
#[cfg(test)]
mod tests {
    use super::*;

    /// Offset from a segment to one long edge of its quad: the unit
    /// perpendicular `(-dy, dx) / len` scaled by half the thickness, mirroring
    /// the offset computed in `GeometryBatch::add_line`.
    fn edge_offset(x1: f32, y1: f32, x2: f32, y2: f32, thickness: f32) -> (f32, f32) {
        let (dx, dy) = (x2 - x1, y2 - y1);
        let len = (dx * dx + dy * dy).sqrt();
        let half = thickness * 0.5;
        (-dy / len * half, dx / len * half)
    }

    #[test]
    fn geo_vertex_is_24_bytes() {
        assert_eq!(std::mem::size_of::<GeoVertex>(), 24);
    }

    #[test]
    fn line_quad_geometry_is_correct() {
        let (nx, ny) = edge_offset(0.0, 0.0, 10.0, 0.0, 2.0);

        assert!(nx.abs() < 1e-6, "nx should be 0 for horizontal line");
        assert!((ny - 1.0).abs() < 1e-6, "ny should be 1 for horizontal line");
    }

    #[test]
    fn diagonal_line_perpendicular() {
        let (dx, dy) = (10.0f32, 10.0f32);
        let (nx, ny) = edge_offset(0.0, 0.0, dx, dy, 2.0);

        // The name promises perpendicularity, so check the dot product is zero
        // as well as the length.
        let dot = nx * dx + ny * dy;
        assert!(dot.abs() < 1e-4, "offset should be perpendicular to the segment");

        let perp_len = (nx * nx + ny * ny).sqrt();
        assert!((perp_len - 1.0).abs() < 1e-6, "perpendicular length should be half-thickness");
    }
}