1use bevy_ecs::prelude::*;
13use glam::{Vec2, Vec3};
14use bytemuck::{Pod, Zeroable};
15use wgpu::{self, VertexBufferLayout, VertexAttribute, VertexFormat, VertexStepMode};
16use wgpu::util::DeviceExt;
17
18use super::buffer::Vertex;
19
20#[repr(C)]
41#[derive(Copy, Clone, Debug, Pod, Zeroable)]
42pub struct SpriteVertex {
43 pub position: [f32; 3], pub texcoord: [f32; 2],
45 pub color: [f32; 3], }
47
48impl Vertex for SpriteVertex {
49 fn layout() -> VertexBufferLayout<'static> {
50 const ATTRIBUTES: &[VertexAttribute] = &[
51 VertexAttribute {
52 offset: 0,
53 shader_location: 0,
54 format: VertexFormat::Float32x3,
55 },
56 VertexAttribute {
57 offset: 12,
58 shader_location: 1,
59 format: VertexFormat::Float32x2,
60 },
61 VertexAttribute {
62 offset: 20,
63 shader_location: 2,
64 format: VertexFormat::Float32x3,
65 },
66 ];
67
68 VertexBufferLayout {
69 array_stride: std::mem::size_of::<SpriteVertex>() as u64,
70 step_mode: VertexStepMode::Vertex,
71 attributes: ATTRIBUTES,
72 }
73 }
74}
75
/// Sub-rectangle of a texture expressed in texture coordinates.
///
/// [`AtlasRect::full`] spans `0.0..=1.0` on both axes, i.e. the whole texture.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct AtlasRect {
    /// Left edge (u).
    pub u_min: f32,
    /// Top edge (v).
    pub v_min: f32,
    /// Right edge (u).
    pub u_max: f32,
    /// Bottom edge (v).
    pub v_max: f32,
}

impl AtlasRect {
    /// Builds a rectangle from its four edges; no ordering or range checks are applied.
    pub fn new(u_min: f32, v_min: f32, u_max: f32, v_max: f32) -> Self {
        AtlasRect { u_min, v_min, u_max, v_max }
    }

    /// The rectangle covering the entire texture: `(0, 0)` to `(1, 1)`.
    pub fn full() -> Self {
        AtlasRect::new(0.0, 0.0, 1.0, 1.0)
    }

    /// Horizontal extent, `u_max - u_min` (negative if the edges are swapped).
    pub fn width(&self) -> f32 {
        let AtlasRect { u_min, u_max, .. } = *self;
        u_max - u_min
    }

    /// Vertical extent, `v_max - v_min` (negative if the edges are swapped).
    pub fn height(&self) -> f32 {
        let AtlasRect { v_min, v_max, .. } = *self;
        v_max - v_min
    }
}

impl Default for AtlasRect {
    /// Defaults to the whole texture, same as [`AtlasRect::full`].
    fn default() -> Self {
        AtlasRect::full()
    }
}
120
121pub struct TextureAtlas {
136 pub width: u32,
138 pub height: u32,
140 rects: std::collections::HashMap<String, AtlasRect>,
142}
143
144impl TextureAtlas {
145 pub fn new(width: u32, height: u32) -> Self {
146 Self {
147 width,
148 height,
149 rects: std::collections::HashMap::new(),
150 }
151 }
152
153 pub fn add_rect(&mut self, name: &str, rect: AtlasRect) {
155 self.rects.insert(name.to_string(), rect);
156 }
157
158 pub fn add_rect_pixels(&mut self, name: &str, x: u32, y: u32, w: u32, h: u32) {
160 let rect = AtlasRect::new(
161 x as f32 / self.width as f32,
162 y as f32 / self.height as f32,
163 (x + w) as f32 / self.width as f32,
164 (y + h) as f32 / self.height as f32,
165 );
166 self.rects.insert(name.to_string(), rect);
167 }
168
169 pub fn get_rect(&self, name: &str) -> Option<&AtlasRect> {
171 self.rects.get(name)
172 }
173
174 pub fn rect_count(&self) -> usize {
176 self.rects.len()
177 }
178
179 pub fn from_grid(width: u32, height: u32, cols: u32, rows: u32) -> Self {
181 let mut atlas = Self::new(width, height);
182 let cell_w = 1.0 / cols as f32;
183 let cell_h = 1.0 / rows as f32;
184 for row in 0..rows {
185 for col in 0..cols {
186 let name = format!("{}_{}", col, row);
187 atlas.add_rect(&name, AtlasRect::new(
188 col as f32 * cell_w,
189 row as f32 * cell_h,
190 (col + 1) as f32 * cell_w,
191 (row + 1) as f32 * cell_h,
192 ));
193 }
194 }
195 atlas
196 }
197}
198
199#[derive(Debug, Clone, Component)]
219pub struct Sprite {
220 pub size: Vec2,
222 pub color: [f32; 3],
224 pub atlas_rect: AtlasRect,
226 pub flip_x: bool,
228 pub flip_y: bool,
230 pub z_order: f32,
232}
233
234impl Default for Sprite {
235 fn default() -> Self {
236 Self {
237 size: Vec2::new(64.0, 64.0),
238 color: [1.0, 1.0, 1.0],
239 atlas_rect: AtlasRect::full(),
240 flip_x: false,
241 flip_y: false,
242 z_order: 0.0,
243 }
244 }
245}
246
247#[derive(Default)]
251pub struct SpriteBatch {
252 pub vertices: Vec<SpriteVertex>,
254}
255
256impl SpriteBatch {
257 pub fn new() -> Self {
258 Self::default()
259 }
260
261 pub fn clear(&mut self) {
262 self.vertices.clear();
263 }
264
265 pub fn add_sprite(&mut self, position: Vec3, sprite: &Sprite) {
267 let half = sprite.size * 0.5;
268 let r = &sprite.atlas_rect;
269
270 let (u_min, u_max) = if sprite.flip_x { (r.u_max, r.u_min) } else { (r.u_min, r.u_max) };
271 let (v_min, v_max) = if sprite.flip_y { (r.v_max, r.v_min) } else { (r.v_min, r.v_max) };
272
273 let z = sprite.z_order;
274 let c = sprite.color;
275
276 let tl = SpriteVertex { position: [position.x - half.x, position.y + half.y, z], texcoord: [u_min, v_min], color: c };
278 let bl = SpriteVertex { position: [position.x - half.x, position.y - half.y, z], texcoord: [u_min, v_max], color: c };
279 let br = SpriteVertex { position: [position.x + half.x, position.y - half.y, z], texcoord: [u_max, v_max], color: c };
280 let tr = SpriteVertex { position: [position.x + half.x, position.y + half.y, z], texcoord: [u_max, v_min], color: c };
281
282 self.vertices.extend_from_slice(&[tl, bl, br, tl, br, tr]);
283 }
284
285 pub fn sprite_count(&self) -> usize {
287 self.vertices.len() / 6
288 }
289
290 pub fn sort_by_z_order(&mut self) {
292 let sprite_count = self.sprite_count();
294 if sprite_count <= 1 { return; }
295
296 let mut sprites: Vec<[SpriteVertex; 6]> = Vec::with_capacity(sprite_count);
297 for chunk in self.vertices.chunks_exact(6) {
298 sprites.push([chunk[0], chunk[1], chunk[2], chunk[3], chunk[4], chunk[5]]);
299 }
300
301 sprites.sort_by(|a, b| a[0].position[2].partial_cmp(&b[0].position[2]).unwrap_or(std::cmp::Ordering::Equal));
302
303 self.vertices.clear();
304 for sprite in sprites {
305 self.vertices.extend_from_slice(&sprite);
306 }
307 }
308}
309
310const SPRITE_SHADER: &str = include_str!("../shaders/sprite.wgsl");
315
316#[repr(C)]
318#[derive(Copy, Clone, Pod, Zeroable)]
319pub struct OrthoUniform {
320 pub projection: [[f32; 4]; 4],
321}
322
323pub struct SpriteRenderer {
325 pub pipeline: wgpu::RenderPipeline,
326 pub ortho_buffer: wgpu::Buffer,
327 pub ortho_bind_group: wgpu::BindGroup,
328 pub ortho_bind_group_layout: wgpu::BindGroupLayout,
329 pub texture_bind_group_layout: wgpu::BindGroupLayout,
330 cached_vb: Option<(wgpu::Buffer, u64)>,
332}
333
334impl SpriteRenderer {
335 pub fn new(device: &super::RenderDevice, format: wgpu::TextureFormat) -> Self {
337 let shader = device.device().create_shader_module(wgpu::ShaderModuleDescriptor {
338 label: Some("Sprite Shader"),
339 source: wgpu::ShaderSource::Wgsl(SPRITE_SHADER.into()),
340 });
341
342 let ortho_bgl = device.device().create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
344 label: Some("Sprite Ortho BGL"),
345 entries: &[wgpu::BindGroupLayoutEntry {
346 binding: 0,
347 visibility: wgpu::ShaderStages::VERTEX,
348 ty: wgpu::BindingType::Buffer {
349 ty: wgpu::BufferBindingType::Uniform,
350 has_dynamic_offset: false,
351 min_binding_size: None,
352 },
353 count: None,
354 }],
355 });
356
357 let tex_bgl = device.device().create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
359 label: Some("Sprite Texture BGL"),
360 entries: &[
361 wgpu::BindGroupLayoutEntry {
362 binding: 0,
363 visibility: wgpu::ShaderStages::FRAGMENT,
364 ty: wgpu::BindingType::Texture {
365 sample_type: wgpu::TextureSampleType::Float { filterable: true },
366 view_dimension: wgpu::TextureViewDimension::D2,
367 multisampled: false,
368 },
369 count: None,
370 },
371 wgpu::BindGroupLayoutEntry {
372 binding: 1,
373 visibility: wgpu::ShaderStages::FRAGMENT,
374 ty: wgpu::BindingType::Sampler(wgpu::SamplerBindingType::Filtering),
375 count: None,
376 },
377 ],
378 });
379
380 let pipeline_layout = device.device().create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
381 label: Some("Sprite Pipeline Layout"),
382 bind_group_layouts: &[&ortho_bgl, &tex_bgl],
383 push_constant_ranges: &[],
384 });
385
386 let pipeline = device.device().create_render_pipeline(&wgpu::RenderPipelineDescriptor {
387 label: Some("Sprite Pipeline"),
388 layout: Some(&pipeline_layout),
389 vertex: wgpu::VertexState {
390 module: &shader,
391 entry_point: "vs_main",
392 buffers: &[SpriteVertex::layout()],
393 },
394 fragment: Some(wgpu::FragmentState {
395 module: &shader,
396 entry_point: "fs_main",
397 targets: &[Some(wgpu::ColorTargetState {
398 format,
399 blend: Some(wgpu::BlendState::ALPHA_BLENDING),
400 write_mask: wgpu::ColorWrites::ALL,
401 })],
402 }),
403 primitive: wgpu::PrimitiveState {
404 topology: wgpu::PrimitiveTopology::TriangleList,
405 ..Default::default()
406 },
407 depth_stencil: None,
408 multisample: wgpu::MultisampleState::default(),
409 multiview: None,
410 });
411
412 let initial = OrthoUniform {
414 projection: glam::Mat4::IDENTITY.to_cols_array_2d(),
415 };
416 let ortho_buffer = device.device().create_buffer_init(&wgpu::util::BufferInitDescriptor {
417 label: Some("Sprite Ortho UB"),
418 contents: bytemuck::bytes_of(&initial),
419 usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
420 });
421
422 let ortho_bg = device.device().create_bind_group(&wgpu::BindGroupDescriptor {
423 label: Some("Sprite Ortho BG"),
424 layout: &ortho_bgl,
425 entries: &[wgpu::BindGroupEntry {
426 binding: 0,
427 resource: ortho_buffer.as_entire_binding(),
428 }],
429 });
430
431 Self {
432 pipeline,
433 ortho_buffer,
434 ortho_bind_group: ortho_bg,
435 ortho_bind_group_layout: ortho_bgl,
436 texture_bind_group_layout: tex_bgl,
437 cached_vb: None,
438 }
439 }
440
441 pub fn render(
443 &mut self,
444 device: &super::RenderDevice,
445 encoder: &mut wgpu::CommandEncoder,
446 target: &wgpu::TextureView,
447 batch: &SpriteBatch,
448 texture_bind_group: &wgpu::BindGroup,
449 screen_width: f32,
450 screen_height: f32,
451 ) {
452 if batch.vertices.is_empty() {
453 return;
454 }
455
456 let ortho = glam::Mat4::orthographic_lh(0.0, screen_width, screen_height, 0.0, -1.0, 1.0);
458 let uniform = OrthoUniform {
459 projection: ortho.to_cols_array_2d(),
460 };
461 device.queue().write_buffer(&self.ortho_buffer, 0, bytemuck::bytes_of(&uniform));
462
463 let data = bytemuck::cast_slice(&batch.vertices);
465 let needed = data.len() as u64;
466 let reuse = self.cached_vb.as_ref().map_or(false, |(_, cap)| *cap >= needed);
467 if !reuse {
468 self.cached_vb = Some((
469 device.device().create_buffer(&wgpu::BufferDescriptor {
470 label: Some("Sprite VB (cached)"),
471 size: needed,
472 usage: wgpu::BufferUsages::VERTEX | wgpu::BufferUsages::COPY_DST,
473 mapped_at_creation: false,
474 }),
475 needed,
476 ));
477 }
478 let vb = &self.cached_vb.as_ref().unwrap().0;
479 device.queue().write_buffer(vb, 0, data);
480
481 {
482 let mut rp = encoder.begin_render_pass(&wgpu::RenderPassDescriptor {
483 label: Some("Sprite Pass"),
484 color_attachments: &[Some(wgpu::RenderPassColorAttachment {
485 view: target,
486 resolve_target: None,
487 ops: wgpu::Operations {
488 load: wgpu::LoadOp::Load,
489 store: wgpu::StoreOp::Store,
490 },
491 })],
492 depth_stencil_attachment: None,
493 timestamp_writes: None,
494 occlusion_query_set: None,
495 });
496
497 rp.set_pipeline(&self.pipeline);
498 rp.set_bind_group(0, &self.ortho_bind_group, &[]);
499 rp.set_bind_group(1, texture_bind_group, &[]);
500 rp.set_vertex_buffer(0, vb.slice(..));
501 rp.draw(0..batch.vertices.len() as u32, 0..1);
502 }
503 }
504}
505
#[cfg(test)]
mod tests {
    use super::*;

    /// Depth of every complete quad in `batch`, read from the quad's first vertex.
    fn quad_depths(batch: &SpriteBatch) -> Vec<f32> {
        batch
            .vertices
            .chunks_exact(6)
            .map(|quad| quad[0].position[2])
            .collect()
    }

    /// Loose float comparison used for UV checks.
    fn approx_eq(a: f32, b: f32) -> bool {
        (a - b).abs() < 0.001
    }

    #[test]
    fn test_sprite_vertex_size() {
        let expected_bytes = 32;
        assert_eq!(std::mem::size_of::<SpriteVertex>(), expected_bytes);
    }

    #[test]
    fn test_atlas_rect() {
        let full = AtlasRect::full();
        assert_eq!((full.width(), full.height()), (1.0, 1.0));
    }

    #[test]
    fn test_texture_atlas_grid() {
        let atlas = TextureAtlas::from_grid(256, 256, 4, 4);
        assert_eq!(atlas.rect_count(), 4 * 4);

        let first = atlas.get_rect("0_0").unwrap();
        assert!(approx_eq(first.u_min, 0.0));
        assert!(approx_eq(first.u_max, 0.25));
    }

    #[test]
    fn test_sprite_batch() {
        let sprite = Sprite::default();
        let mut batch = SpriteBatch::new();

        batch.add_sprite(Vec3::new(100.0, 200.0, 0.0), &sprite);
        assert_eq!((batch.sprite_count(), batch.vertices.len()), (1, 6));

        batch.add_sprite(Vec3::new(300.0, 200.0, 1.0), &sprite);
        assert_eq!(batch.sprite_count(), 2);
    }

    #[test]
    fn test_sprite_batch_z_sort() {
        let mut batch = SpriteBatch::new();
        for z_order in [2.0, 0.0, 1.0] {
            let sprite = Sprite { z_order, ..Default::default() };
            batch.add_sprite(Vec3::ZERO, &sprite);
        }

        batch.sort_by_z_order();

        let depths = quad_depths(&batch);
        assert_eq!(&depths[..3], &[0.0_f32, 1.0, 2.0]);
    }
}