1use std::collections::HashMap;
5
6use bytemuck::{Pod, Zeroable};
7use wgpu::util::DeviceExt;
8
9use crate::target::Targets;
10use crate::transform::Mat4;
11
/// Opaque handle to a mesh uploaded through the mesh registry.
///
/// A handle is just a `u32`, so it is cheap to copy, hash, and compare; its
/// total ordering is what lets queued instances be sorted and grouped by mesh.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct MeshId(pub u32);
15
16#[repr(C)]
18#[derive(Copy, Clone, Pod, Zeroable, Debug)]
19pub struct MeshVertex {
20 pub pos: [f32; 3],
21 pub normal: [f32; 3],
22}
23
24impl MeshVertex {
25 pub fn new(pos: [f32; 3], normal: [f32; 3]) -> Self {
26 Self { pos, normal }
27 }
28}
29
30#[repr(C)]
32#[derive(Copy, Clone, Pod, Zeroable, Debug)]
33pub struct MeshInstance {
34 pub model: Mat4,
35 pub color: [f32; 4],
36}
37
38impl MeshInstance {
39 pub fn new(model: Mat4, color: [f32; 4]) -> Self {
40 Self { model, color }
41 }
42}
43
44struct MeshGpu {
45 vertex_buf: wgpu::Buffer,
46 index_buf: wgpu::Buffer,
47 index_count: u32,
48}
49
50pub(crate) struct MeshRegistry {
52 map: HashMap<MeshId, MeshGpu>,
53 next: u32,
54}
55
56impl MeshRegistry {
57 pub fn new() -> Self {
58 Self {
59 map: HashMap::new(),
60 next: 1,
61 }
62 }
63
64 pub fn create(
65 &mut self,
66 device: &wgpu::Device,
67 vertices: &[MeshVertex],
68 indices: &[u16],
69 ) -> MeshId {
70 let vertex_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
71 label: Some("mesh.vertices"),
72 contents: bytemuck::cast_slice(vertices),
73 usage: wgpu::BufferUsages::VERTEX,
74 });
75 let index_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
76 label: Some("mesh.indices"),
77 contents: bytemuck::cast_slice(indices),
78 usage: wgpu::BufferUsages::INDEX,
79 });
80 let id = MeshId(self.next);
81 self.next += 1;
82 self.map.insert(
83 id,
84 MeshGpu {
85 vertex_buf,
86 index_buf,
87 index_count: indices.len() as u32,
88 },
89 );
90 id
91 }
92}
93
94pub(crate) struct MeshBatcher {
95 pipeline: wgpu::RenderPipeline,
96 instance_vb: wgpu::Buffer,
97 capacity: usize,
98 pending: Vec<(MeshId, MeshInstance)>,
99}
100
101impl MeshBatcher {
102 pub fn new(
103 device: &wgpu::Device,
104 surface_format: wgpu::TextureFormat,
105 camera_bgl: &wgpu::BindGroupLayout,
106 sample_count: u32,
107 depth_format: Option<wgpu::TextureFormat>,
108 ) -> Self {
109 let capacity = 256usize;
110 let instance_vb = device.create_buffer(&wgpu::BufferDescriptor {
111 label: Some("mesh.instances"),
112 size: (capacity * std::mem::size_of::<MeshInstance>()) as u64,
113 usage: wgpu::BufferUsages::VERTEX | wgpu::BufferUsages::COPY_DST,
114 mapped_at_creation: false,
115 });
116
117 let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
118 label: Some("mesh.shader"),
119 source: wgpu::ShaderSource::Wgsl(include_str!("mesh.wgsl").into()),
120 });
121 let layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
122 label: Some("mesh.layout"),
123 bind_group_layouts: &[Some(camera_bgl)],
124 immediate_size: 0,
125 });
126
127 let pipeline = device.create_render_pipeline(&wgpu::RenderPipelineDescriptor {
128 label: Some("mesh.pipeline"),
129 layout: Some(&layout),
130 vertex: wgpu::VertexState {
131 module: &shader,
132 entry_point: Some("vs_main"),
133 compilation_options: Default::default(),
134 buffers: &[
135 wgpu::VertexBufferLayout {
136 array_stride: std::mem::size_of::<MeshVertex>() as u64,
137 step_mode: wgpu::VertexStepMode::Vertex,
138 attributes: &wgpu::vertex_attr_array![0 => Float32x3, 1 => Float32x3],
139 },
140 wgpu::VertexBufferLayout {
141 array_stride: std::mem::size_of::<MeshInstance>() as u64,
142 step_mode: wgpu::VertexStepMode::Instance,
143 attributes: &wgpu::vertex_attr_array![
144 2 => Float32x4,
145 3 => Float32x4,
146 4 => Float32x4,
147 5 => Float32x4,
148 6 => Float32x4,
149 ],
150 },
151 ],
152 },
153 fragment: Some(wgpu::FragmentState {
154 module: &shader,
155 entry_point: Some("fs_main"),
156 compilation_options: Default::default(),
157 targets: &[Some(wgpu::ColorTargetState {
158 format: surface_format,
159 blend: Some(wgpu::BlendState::ALPHA_BLENDING),
160 write_mask: wgpu::ColorWrites::ALL,
161 })],
162 }),
163 primitive: wgpu::PrimitiveState::default(),
166 depth_stencil: depth_format.map(crate::target::depth_test),
167 multisample: crate::target::multisample(sample_count),
168 multiview_mask: None,
169 cache: None,
170 });
171
172 Self {
173 pipeline,
174 instance_vb,
175 capacity,
176 pending: Vec::new(),
177 }
178 }
179
180 pub fn push(&mut self, id: MeshId, instance: MeshInstance) {
181 self.pending.push((id, instance));
182 }
183
184 pub fn draw(
185 &mut self,
186 device: &wgpu::Device,
187 queue: &wgpu::Queue,
188 registry: &MeshRegistry,
189 encoder: &mut wgpu::CommandEncoder,
190 targets: &Targets,
191 camera_bg: &wgpu::BindGroup,
192 ) {
193 if self.pending.is_empty() {
194 return;
195 }
196 self.pending.sort_by_key(|(id, _)| *id);
198 if self.pending.len() > self.capacity {
199 self.capacity = self.pending.len().next_power_of_two();
200 self.instance_vb = device.create_buffer(&wgpu::BufferDescriptor {
201 label: Some("mesh.instances"),
202 size: (self.capacity * std::mem::size_of::<MeshInstance>()) as u64,
203 usage: wgpu::BufferUsages::VERTEX | wgpu::BufferUsages::COPY_DST,
204 mapped_at_creation: false,
205 });
206 }
207 let flat: Vec<MeshInstance> = self.pending.iter().map(|(_, i)| *i).collect();
208 queue.write_buffer(&self.instance_vb, 0, bytemuck::cast_slice(&flat));
209
210 let mut pass = encoder.begin_render_pass(&wgpu::RenderPassDescriptor {
211 label: Some("mesh.pass"),
212 color_attachments: &[Some(targets.color_attachment(wgpu::LoadOp::Load))],
213 depth_stencil_attachment: targets.depth_attachment(wgpu::LoadOp::Load),
214 occlusion_query_set: None,
215 timestamp_writes: None,
216 multiview_mask: None,
217 });
218 pass.set_pipeline(&self.pipeline);
219 pass.set_bind_group(0, camera_bg, &[]);
220 pass.set_vertex_buffer(1, self.instance_vb.slice(..));
221
222 let mut i = 0;
223 while i < self.pending.len() {
224 let id = self.pending[i].0;
225 let start = i;
226 while i < self.pending.len() && self.pending[i].0 == id {
227 i += 1;
228 }
229 let Some(mesh) = registry.map.get(&id) else {
230 continue;
231 };
232 pass.set_vertex_buffer(0, mesh.vertex_buf.slice(..));
233 pass.set_index_buffer(mesh.index_buf.slice(..), wgpu::IndexFormat::Uint16);
234 pass.draw_indexed(0..mesh.index_count, 0, (start as u32)..(i as u32));
235 }
236 }
237
238 pub fn clear(&mut self) {
239 self.pending.clear();
240 }
241}