//! Sprite batching for the renderer (`game_toolkit_gfx/sprite.rs`).

1use std::collections::HashMap;
2
3use bytemuck::{Pod, Zeroable};
4use wgpu::util::DeviceExt;
5
6use crate::target::Targets;
7use crate::texture::{TextureId, TextureRegistry};
8
9#[repr(C)]
10#[derive(Copy, Clone, Pod, Zeroable)]
11struct QuadVertex {
12    pos: [f32; 2],
13    uv: [f32; 2],
14}
15
16const QUAD_VERTS: &[QuadVertex] = &[
17    QuadVertex {
18        pos: [0.0, 0.0],
19        uv: [0.0, 0.0],
20    },
21    QuadVertex {
22        pos: [1.0, 0.0],
23        uv: [1.0, 0.0],
24    },
25    QuadVertex {
26        pos: [1.0, 1.0],
27        uv: [1.0, 1.0],
28    },
29    QuadVertex {
30        pos: [0.0, 1.0],
31        uv: [0.0, 1.0],
32    },
33];
34const QUAD_INDICES: &[u16] = &[0, 1, 2, 0, 2, 3];
35
36#[repr(C)]
37#[derive(Copy, Clone, Pod, Zeroable, Debug)]
38pub struct SpriteInstance {
39    pub pos: [f32; 2],
40    pub size: [f32; 2],
41    pub uv_min: [f32; 2],
42    pub uv_max: [f32; 2],
43    pub color: [f32; 4],
44    pub rotation: f32,
45    pub _pad: [f32; 3],
46}
47
48impl SpriteInstance {
49    pub fn at(pos: [f32; 2], size: [f32; 2]) -> Self {
50        Self {
51            pos,
52            size,
53            uv_min: [0.0, 0.0],
54            uv_max: [1.0, 1.0],
55            color: [1.0; 4],
56            rotation: 0.0,
57            _pad: [0.0; 3],
58        }
59    }
60    pub fn with_color(mut self, c: [f32; 4]) -> Self {
61        self.color = c;
62        self
63    }
64    pub fn with_rotation(mut self, r: f32) -> Self {
65        self.rotation = r;
66        self
67    }
68    pub fn with_uv(mut self, min: [f32; 2], max: [f32; 2]) -> Self {
69        self.uv_min = min;
70        self.uv_max = max;
71        self
72    }
73}
74
/// How a sprite is combined with what is already in the render target.
///
/// The discriminants are spelled out because `SpriteBatcher::upload` orders batches by
/// `blend as u8`; renumbering or reordering the variants changes draw order within a layer.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BlendMode {
    /// Drawn with `wgpu::BlendState::ALPHA_BLENDING`.
    Alpha = 0,
    /// Source colour scaled by source alpha, then added onto the destination colour.
    Additive = 1,
    /// Drawn with `wgpu::BlendState::PREMULTIPLIED_ALPHA_BLENDING`.
    Premultiplied = 2,
}
81
82#[derive(Copy, Clone, PartialEq, Eq, Hash)]
83struct BatchKey {
84    texture: TextureId,
85    layer: i16,
86    blend: BlendMode,
87}
88
89pub(crate) struct SpriteBatcher {
90    quad_vb: wgpu::Buffer,
91    quad_ib: wgpu::Buffer,
92    instance_vb: wgpu::Buffer,
93    instance_capacity: usize,
94    pipelines: HashMap<BlendMode, wgpu::RenderPipeline>,
95    pending: Vec<(BatchKey, SpriteInstance)>,
96}
97
98impl SpriteBatcher {
99    pub fn new(
100        device: &wgpu::Device,
101        surface_format: wgpu::TextureFormat,
102        camera_bgl: &wgpu::BindGroupLayout,
103        texture_bgl: &wgpu::BindGroupLayout,
104        sample_count: u32,
105        depth_format: Option<wgpu::TextureFormat>,
106    ) -> Self {
107        let multisample = crate::target::multisample(sample_count);
108        let depth_stencil = depth_format.map(crate::target::no_write_depth);
109        let quad_vb = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
110            label: Some("sprite.quad_vb"),
111            contents: bytemuck::cast_slice(QUAD_VERTS),
112            usage: wgpu::BufferUsages::VERTEX,
113        });
114        let quad_ib = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
115            label: Some("sprite.quad_ib"),
116            contents: bytemuck::cast_slice(QUAD_INDICES),
117            usage: wgpu::BufferUsages::INDEX,
118        });
119
120        let instance_capacity = 4096usize;
121        let instance_vb = device.create_buffer(&wgpu::BufferDescriptor {
122            label: Some("sprite.instances"),
123            size: (instance_capacity * std::mem::size_of::<SpriteInstance>()) as u64,
124            usage: wgpu::BufferUsages::VERTEX | wgpu::BufferUsages::COPY_DST,
125            mapped_at_creation: false,
126        });
127
128        let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
129            label: Some("sprite.shader"),
130            source: wgpu::ShaderSource::Wgsl(include_str!("sprite.wgsl").into()),
131        });
132
133        let layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
134            label: Some("sprite.layout"),
135            bind_group_layouts: &[Some(camera_bgl), Some(texture_bgl)],
136            immediate_size: 0,
137        });
138
139        let make_pipeline = |blend: wgpu::BlendState, label: &'static str| {
140            device.create_render_pipeline(&wgpu::RenderPipelineDescriptor {
141                label: Some(label),
142                layout: Some(&layout),
143                vertex: wgpu::VertexState {
144                    module: &shader,
145                    entry_point: Some("vs_main"),
146                    compilation_options: Default::default(),
147                    buffers: &[
148                        wgpu::VertexBufferLayout {
149                            array_stride: std::mem::size_of::<QuadVertex>() as u64,
150                            step_mode: wgpu::VertexStepMode::Vertex,
151                            attributes: &wgpu::vertex_attr_array![0 => Float32x2, 1 => Float32x2],
152                        },
153                        wgpu::VertexBufferLayout {
154                            array_stride: std::mem::size_of::<SpriteInstance>() as u64,
155                            step_mode: wgpu::VertexStepMode::Instance,
156                            attributes: &wgpu::vertex_attr_array![
157                                2 => Float32x2,
158                                3 => Float32x2,
159                                4 => Float32x2,
160                                5 => Float32x2,
161                                6 => Float32x4,
162                                7 => Float32,
163                            ],
164                        },
165                    ],
166                },
167                fragment: Some(wgpu::FragmentState {
168                    module: &shader,
169                    entry_point: Some("fs_main"),
170                    compilation_options: Default::default(),
171                    targets: &[Some(wgpu::ColorTargetState {
172                        format: surface_format,
173                        blend: Some(blend),
174                        write_mask: wgpu::ColorWrites::ALL,
175                    })],
176                }),
177                primitive: wgpu::PrimitiveState::default(),
178                depth_stencil: depth_stencil.clone(),
179                multisample,
180                multiview_mask: None,
181                cache: None,
182            })
183        };
184
185        let mut pipelines = HashMap::new();
186        pipelines.insert(
187            BlendMode::Alpha,
188            make_pipeline(wgpu::BlendState::ALPHA_BLENDING, "sprite.alpha"),
189        );
190        pipelines.insert(
191            BlendMode::Premultiplied,
192            make_pipeline(
193                wgpu::BlendState::PREMULTIPLIED_ALPHA_BLENDING,
194                "sprite.premul",
195            ),
196        );
197        pipelines.insert(
198            BlendMode::Additive,
199            make_pipeline(
200                wgpu::BlendState {
201                    color: wgpu::BlendComponent {
202                        src_factor: wgpu::BlendFactor::SrcAlpha,
203                        dst_factor: wgpu::BlendFactor::One,
204                        operation: wgpu::BlendOperation::Add,
205                    },
206                    alpha: wgpu::BlendComponent::OVER,
207                },
208                "sprite.add",
209            ),
210        );
211
212        Self {
213            quad_vb,
214            quad_ib,
215            instance_vb,
216            instance_capacity,
217            pipelines,
218            pending: Vec::new(),
219        }
220    }
221
222    pub fn draw(&mut self, tex: TextureId, layer: i16, blend: BlendMode, inst: SpriteInstance) {
223        self.pending.push((
224            BatchKey {
225                texture: tex,
226                layer,
227                blend,
228            },
229            inst,
230        ));
231    }
232
233    /// Record every layer that has pending sprites, for cross-batcher interleaving.
234    pub fn collect_layers(&self, out: &mut std::collections::BTreeSet<i16>) {
235        out.extend(self.pending.iter().map(|(k, _)| k.layer));
236    }
237
238    /// Sort all pending sprites by layer (then blend, then texture) and upload them to the
239    /// instance buffer in one write. Must run before any [`SpriteBatcher::draw_layer`]: the
240    /// buffer is written once per frame because `queue.write_buffer` does not interleave with
241    /// encoder passes, so a per-pass write would clobber the earlier layers' data.
242    pub fn upload(&mut self, device: &wgpu::Device, queue: &wgpu::Queue) {
243        if self.pending.is_empty() {
244            return;
245        }
246        self.pending.sort_by(|a, b| {
247            a.0.layer
248                .cmp(&b.0.layer)
249                .then((a.0.blend as u8).cmp(&(b.0.blend as u8)))
250                .then(a.0.texture.0.cmp(&b.0.texture.0))
251        });
252        if self.pending.len() > self.instance_capacity {
253            self.instance_capacity = self.pending.len().next_power_of_two();
254            self.instance_vb = device.create_buffer(&wgpu::BufferDescriptor {
255                label: Some("sprite.instances"),
256                size: (self.instance_capacity * std::mem::size_of::<SpriteInstance>()) as u64,
257                usage: wgpu::BufferUsages::VERTEX | wgpu::BufferUsages::COPY_DST,
258                mapped_at_creation: false,
259            });
260        }
261        let flat: Vec<SpriteInstance> = self.pending.iter().map(|(_, i)| *i).collect();
262        queue.write_buffer(&self.instance_vb, 0, bytemuck::cast_slice(&flat));
263    }
264
265    /// Draw the sprites on `layer` (already uploaded by [`SpriteBatcher::upload`]) into the
266    /// already-cleared target, grouped by blend then texture to minimize state changes.
267    pub fn draw_layer(
268        &self,
269        layer: i16,
270        encoder: &mut wgpu::CommandEncoder,
271        targets: &Targets,
272        camera_bg: &wgpu::BindGroup,
273        textures: &TextureRegistry,
274    ) {
275        // `pending` is sorted by layer, so this layer's sprites are a contiguous range.
276        let lo = self.pending.partition_point(|(k, _)| k.layer < layer);
277        let hi = self.pending.partition_point(|(k, _)| k.layer <= layer);
278        if lo == hi {
279            return;
280        }
281
282        let mut pass = encoder.begin_render_pass(&wgpu::RenderPassDescriptor {
283            label: Some("sprite.pass"),
284            color_attachments: &[Some(targets.color_attachment(wgpu::LoadOp::Load))],
285            depth_stencil_attachment: targets.depth_attachment(wgpu::LoadOp::Load),
286            occlusion_query_set: None,
287            timestamp_writes: None,
288            multiview_mask: None,
289        });
290
291        pass.set_bind_group(0, camera_bg, &[]);
292        pass.set_vertex_buffer(0, self.quad_vb.slice(..));
293        pass.set_index_buffer(self.quad_ib.slice(..), wgpu::IndexFormat::Uint16);
294        pass.set_vertex_buffer(1, self.instance_vb.slice(..));
295
296        let mut i = lo;
297        while i < hi {
298            let key = self.pending[i].0;
299            let start = i;
300            while i < hi
301                && self.pending[i].0.blend == key.blend
302                && self.pending[i].0.texture == key.texture
303            {
304                i += 1;
305            }
306            let count = (i - start) as u32;
307            pass.set_pipeline(&self.pipelines[&key.blend]);
308            pass.set_bind_group(1, textures.bind_group(key.texture), &[]);
309            pass.draw_indexed(0..6, 0, (start as u32)..(start as u32 + count));
310        }
311    }
312
313    pub fn clear(&mut self) {
314        self.pending.clear();
315    }
316}