1use std::collections::HashMap;
2
3use bytemuck::{Pod, Zeroable};
4use wgpu::util::DeviceExt;
5
6use crate::target::Targets;
7use crate::texture::{TextureId, TextureRegistry};
8
9#[repr(C)]
10#[derive(Copy, Clone, Pod, Zeroable)]
11struct QuadVertex {
12 pos: [f32; 2],
13 uv: [f32; 2],
14}
15
16const QUAD_VERTS: &[QuadVertex] = &[
17 QuadVertex {
18 pos: [0.0, 0.0],
19 uv: [0.0, 0.0],
20 },
21 QuadVertex {
22 pos: [1.0, 0.0],
23 uv: [1.0, 0.0],
24 },
25 QuadVertex {
26 pos: [1.0, 1.0],
27 uv: [1.0, 1.0],
28 },
29 QuadVertex {
30 pos: [0.0, 1.0],
31 uv: [0.0, 1.0],
32 },
33];
34const QUAD_INDICES: &[u16] = &[0, 1, 2, 0, 2, 3];
35
36#[repr(C)]
37#[derive(Copy, Clone, Pod, Zeroable, Debug)]
38pub struct SpriteInstance {
39 pub pos: [f32; 2],
40 pub size: [f32; 2],
41 pub uv_min: [f32; 2],
42 pub uv_max: [f32; 2],
43 pub color: [f32; 4],
44 pub rotation: f32,
45 pub _pad: [f32; 3],
46}
47
48impl SpriteInstance {
49 pub fn at(pos: [f32; 2], size: [f32; 2]) -> Self {
50 Self {
51 pos,
52 size,
53 uv_min: [0.0, 0.0],
54 uv_max: [1.0, 1.0],
55 color: [1.0; 4],
56 rotation: 0.0,
57 _pad: [0.0; 3],
58 }
59 }
60 pub fn with_color(mut self, c: [f32; 4]) -> Self {
61 self.color = c;
62 self
63 }
64 pub fn with_rotation(mut self, r: f32) -> Self {
65 self.rotation = r;
66 self
67 }
68 pub fn with_uv(mut self, min: [f32; 2], max: [f32; 2]) -> Self {
69 self.uv_min = min;
70 self.uv_max = max;
71 self
72 }
73}
74
/// How a sprite's fragment output is combined with what is already in the
/// colour target. Each variant selects one pre-built render pipeline.
///
/// The discriminants are written out because the batcher casts the mode to
/// `u8` when ordering queued sprites, so the numeric values determine the
/// draw order of blend modes within a layer.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum BlendMode {
    /// Uses `wgpu::BlendState::ALPHA_BLENDING`.
    Alpha = 0,
    /// Colour is `src * src_alpha + dst`; the alpha channel uses
    /// `wgpu::BlendComponent::OVER`.
    Additive = 1,
    /// Uses `wgpu::BlendState::PREMULTIPLIED_ALPHA_BLENDING`.
    Premultiplied = 2,
}
81
82#[derive(Copy, Clone, PartialEq, Eq, Hash)]
83struct BatchKey {
84 texture: TextureId,
85 layer: i16,
86 blend: BlendMode,
87}
88
89pub(crate) struct SpriteBatcher {
90 quad_vb: wgpu::Buffer,
91 quad_ib: wgpu::Buffer,
92 instance_vb: wgpu::Buffer,
93 instance_capacity: usize,
94 pipelines: HashMap<BlendMode, wgpu::RenderPipeline>,
95 pending: Vec<(BatchKey, SpriteInstance)>,
96}
97
98impl SpriteBatcher {
99 pub fn new(
100 device: &wgpu::Device,
101 surface_format: wgpu::TextureFormat,
102 camera_bgl: &wgpu::BindGroupLayout,
103 texture_bgl: &wgpu::BindGroupLayout,
104 sample_count: u32,
105 depth_format: Option<wgpu::TextureFormat>,
106 ) -> Self {
107 let multisample = crate::target::multisample(sample_count);
108 let depth_stencil = depth_format.map(crate::target::no_write_depth);
109 let quad_vb = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
110 label: Some("sprite.quad_vb"),
111 contents: bytemuck::cast_slice(QUAD_VERTS),
112 usage: wgpu::BufferUsages::VERTEX,
113 });
114 let quad_ib = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
115 label: Some("sprite.quad_ib"),
116 contents: bytemuck::cast_slice(QUAD_INDICES),
117 usage: wgpu::BufferUsages::INDEX,
118 });
119
120 let instance_capacity = 4096usize;
121 let instance_vb = device.create_buffer(&wgpu::BufferDescriptor {
122 label: Some("sprite.instances"),
123 size: (instance_capacity * std::mem::size_of::<SpriteInstance>()) as u64,
124 usage: wgpu::BufferUsages::VERTEX | wgpu::BufferUsages::COPY_DST,
125 mapped_at_creation: false,
126 });
127
128 let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
129 label: Some("sprite.shader"),
130 source: wgpu::ShaderSource::Wgsl(include_str!("sprite.wgsl").into()),
131 });
132
133 let layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
134 label: Some("sprite.layout"),
135 bind_group_layouts: &[Some(camera_bgl), Some(texture_bgl)],
136 immediate_size: 0,
137 });
138
139 let make_pipeline = |blend: wgpu::BlendState, label: &'static str| {
140 device.create_render_pipeline(&wgpu::RenderPipelineDescriptor {
141 label: Some(label),
142 layout: Some(&layout),
143 vertex: wgpu::VertexState {
144 module: &shader,
145 entry_point: Some("vs_main"),
146 compilation_options: Default::default(),
147 buffers: &[
148 wgpu::VertexBufferLayout {
149 array_stride: std::mem::size_of::<QuadVertex>() as u64,
150 step_mode: wgpu::VertexStepMode::Vertex,
151 attributes: &wgpu::vertex_attr_array![0 => Float32x2, 1 => Float32x2],
152 },
153 wgpu::VertexBufferLayout {
154 array_stride: std::mem::size_of::<SpriteInstance>() as u64,
155 step_mode: wgpu::VertexStepMode::Instance,
156 attributes: &wgpu::vertex_attr_array![
157 2 => Float32x2,
158 3 => Float32x2,
159 4 => Float32x2,
160 5 => Float32x2,
161 6 => Float32x4,
162 7 => Float32,
163 ],
164 },
165 ],
166 },
167 fragment: Some(wgpu::FragmentState {
168 module: &shader,
169 entry_point: Some("fs_main"),
170 compilation_options: Default::default(),
171 targets: &[Some(wgpu::ColorTargetState {
172 format: surface_format,
173 blend: Some(blend),
174 write_mask: wgpu::ColorWrites::ALL,
175 })],
176 }),
177 primitive: wgpu::PrimitiveState::default(),
178 depth_stencil: depth_stencil.clone(),
179 multisample,
180 multiview_mask: None,
181 cache: None,
182 })
183 };
184
185 let mut pipelines = HashMap::new();
186 pipelines.insert(
187 BlendMode::Alpha,
188 make_pipeline(wgpu::BlendState::ALPHA_BLENDING, "sprite.alpha"),
189 );
190 pipelines.insert(
191 BlendMode::Premultiplied,
192 make_pipeline(
193 wgpu::BlendState::PREMULTIPLIED_ALPHA_BLENDING,
194 "sprite.premul",
195 ),
196 );
197 pipelines.insert(
198 BlendMode::Additive,
199 make_pipeline(
200 wgpu::BlendState {
201 color: wgpu::BlendComponent {
202 src_factor: wgpu::BlendFactor::SrcAlpha,
203 dst_factor: wgpu::BlendFactor::One,
204 operation: wgpu::BlendOperation::Add,
205 },
206 alpha: wgpu::BlendComponent::OVER,
207 },
208 "sprite.add",
209 ),
210 );
211
212 Self {
213 quad_vb,
214 quad_ib,
215 instance_vb,
216 instance_capacity,
217 pipelines,
218 pending: Vec::new(),
219 }
220 }
221
222 pub fn draw(&mut self, tex: TextureId, layer: i16, blend: BlendMode, inst: SpriteInstance) {
223 self.pending.push((
224 BatchKey {
225 texture: tex,
226 layer,
227 blend,
228 },
229 inst,
230 ));
231 }
232
233 pub fn collect_layers(&self, out: &mut std::collections::BTreeSet<i16>) {
235 out.extend(self.pending.iter().map(|(k, _)| k.layer));
236 }
237
238 pub fn upload(&mut self, device: &wgpu::Device, queue: &wgpu::Queue) {
243 if self.pending.is_empty() {
244 return;
245 }
246 self.pending.sort_by(|a, b| {
247 a.0.layer
248 .cmp(&b.0.layer)
249 .then((a.0.blend as u8).cmp(&(b.0.blend as u8)))
250 .then(a.0.texture.0.cmp(&b.0.texture.0))
251 });
252 if self.pending.len() > self.instance_capacity {
253 self.instance_capacity = self.pending.len().next_power_of_two();
254 self.instance_vb = device.create_buffer(&wgpu::BufferDescriptor {
255 label: Some("sprite.instances"),
256 size: (self.instance_capacity * std::mem::size_of::<SpriteInstance>()) as u64,
257 usage: wgpu::BufferUsages::VERTEX | wgpu::BufferUsages::COPY_DST,
258 mapped_at_creation: false,
259 });
260 }
261 let flat: Vec<SpriteInstance> = self.pending.iter().map(|(_, i)| *i).collect();
262 queue.write_buffer(&self.instance_vb, 0, bytemuck::cast_slice(&flat));
263 }
264
265 pub fn draw_layer(
268 &self,
269 layer: i16,
270 encoder: &mut wgpu::CommandEncoder,
271 targets: &Targets,
272 camera_bg: &wgpu::BindGroup,
273 textures: &TextureRegistry,
274 ) {
275 let lo = self.pending.partition_point(|(k, _)| k.layer < layer);
277 let hi = self.pending.partition_point(|(k, _)| k.layer <= layer);
278 if lo == hi {
279 return;
280 }
281
282 let mut pass = encoder.begin_render_pass(&wgpu::RenderPassDescriptor {
283 label: Some("sprite.pass"),
284 color_attachments: &[Some(targets.color_attachment(wgpu::LoadOp::Load))],
285 depth_stencil_attachment: targets.depth_attachment(wgpu::LoadOp::Load),
286 occlusion_query_set: None,
287 timestamp_writes: None,
288 multiview_mask: None,
289 });
290
291 pass.set_bind_group(0, camera_bg, &[]);
292 pass.set_vertex_buffer(0, self.quad_vb.slice(..));
293 pass.set_index_buffer(self.quad_ib.slice(..), wgpu::IndexFormat::Uint16);
294 pass.set_vertex_buffer(1, self.instance_vb.slice(..));
295
296 let mut i = lo;
297 while i < hi {
298 let key = self.pending[i].0;
299 let start = i;
300 while i < hi
301 && self.pending[i].0.blend == key.blend
302 && self.pending[i].0.texture == key.texture
303 {
304 i += 1;
305 }
306 let count = (i - start) as u32;
307 pass.set_pipeline(&self.pipelines[&key.blend]);
308 pass.set_bind_group(1, textures.bind_group(key.texture), &[]);
309 pass.draw_indexed(0..6, 0, (start as u32)..(start as u32 + count));
310 }
311 }
312
313 pub fn clear(&mut self) {
314 self.pending.clear();
315 }
316}