bevy_sprite_render/render/mod.rs

1use core::ops::Range;
2
3use crate::ComputedTextureSlices;
4use bevy_asset::{load_embedded_asset, AssetEvent, AssetId, AssetServer, Assets, Handle};
5use bevy_camera::visibility::ViewVisibility;
6use bevy_color::{ColorToComponents, LinearRgba};
7use bevy_core_pipeline::{
8    core_2d::{Transparent2d, CORE_2D_DEPTH_FORMAT},
9    tonemapping::{
10        get_lut_bind_group_layout_entries, get_lut_bindings, DebandDither, Tonemapping,
11        TonemappingLuts,
12    },
13};
14use bevy_derive::{Deref, DerefMut};
15use bevy_ecs::{
16    prelude::*,
17    query::ROQueryItem,
18    system::{lifetimeless::*, SystemParamItem},
19};
20use bevy_image::{BevyDefault, Image, TextureAtlasLayout};
21use bevy_math::{Affine3A, FloatOrd, Quat, Rect, Vec2, Vec4};
22use bevy_mesh::VertexBufferLayout;
23use bevy_platform::collections::HashMap;
24use bevy_render::view::{RenderVisibleEntities, RetainedViewEntity};
25use bevy_render::{
26    render_asset::RenderAssets,
27    render_phase::{
28        DrawFunctions, PhaseItem, PhaseItemExtraIndex, RenderCommand, RenderCommandResult,
29        SetItemPipeline, TrackedRenderPass, ViewSortedRenderPhases,
30    },
31    render_resource::{
32        binding_types::{sampler, texture_2d, uniform_buffer},
33        *,
34    },
35    renderer::{RenderDevice, RenderQueue},
36    sync_world::RenderEntity,
37    texture::{FallbackImage, GpuImage},
38    view::{ExtractedView, Msaa, ViewTarget, ViewUniform, ViewUniformOffset, ViewUniforms},
39    Extract,
40};
41use bevy_shader::{Shader, ShaderDefVal};
42use bevy_sprite::{Anchor, Sprite, SpriteScalingMode};
43use bevy_transform::components::GlobalTransform;
44use bevy_utils::default;
45use bytemuck::{Pod, Zeroable};
46use fixedbitset::FixedBitSet;
47
48#[derive(Resource)]
49pub struct SpritePipeline {
50    view_layout: BindGroupLayoutDescriptor,
51    material_layout: BindGroupLayoutDescriptor,
52    shader: Handle<Shader>,
53}
54
55pub fn init_sprite_pipeline(mut commands: Commands, asset_server: Res<AssetServer>) {
56    let tonemapping_lut_entries = get_lut_bind_group_layout_entries();
57    let view_layout = BindGroupLayoutDescriptor::new(
58        "sprite_view_layout",
59        &BindGroupLayoutEntries::sequential(
60            ShaderStages::VERTEX_FRAGMENT,
61            (
62                uniform_buffer::<ViewUniform>(true),
63                tonemapping_lut_entries[0].visibility(ShaderStages::FRAGMENT),
64                tonemapping_lut_entries[1].visibility(ShaderStages::FRAGMENT),
65            ),
66        ),
67    );
68
69    let material_layout = BindGroupLayoutDescriptor::new(
70        "sprite_material_layout",
71        &BindGroupLayoutEntries::sequential(
72            ShaderStages::FRAGMENT,
73            (
74                texture_2d(TextureSampleType::Float { filterable: true }),
75                sampler(SamplerBindingType::Filtering),
76            ),
77        ),
78    );
79
80    commands.insert_resource(SpritePipeline {
81        view_layout,
82        material_layout,
83        shader: load_embedded_asset!(asset_server.as_ref(), "sprite.wgsl"),
84    });
85}
86
bitflags::bitflags! {
    /// Specialization key for [`SpritePipeline`].
    ///
    /// The low bits are independent flags. The top 3 bits store log2 of the MSAA sample
    /// count, and the 3 bits directly below them select the in-shader tonemapping method.
    #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
    #[repr(transparent)]
    // NOTE: Apparently quadro drivers support up to 64x MSAA.
    // MSAA uses the highest 3 bits for the MSAA log2(sample count) to support up to 128x MSAA.
    pub struct SpritePipelineKey: u32 {
        // No flags: default (non-HDR) target format, no in-shader tonemapping, no dithering,
        // and an MSAA field of 0, which decodes to 1 sample.
        const NONE                              = 0;
        // Use `ViewTarget::TEXTURE_FORMAT_HDR` as the color target format.
        const HDR                               = 1 << 0;
        // Tonemap in the sprite shader, using the method stored in TONEMAP_METHOD_RESERVED_BITS.
        const TONEMAP_IN_SHADER                 = 1 << 1;
        // Deband dithering; `specialize` only honors it together with TONEMAP_IN_SHADER.
        const DEBAND_DITHER                     = 1 << 2;
        // 3-bit field holding log2(MSAA sample count); see `from_msaa_samples`.
        const MSAA_RESERVED_BITS                = Self::MSAA_MASK_BITS << Self::MSAA_SHIFT_BITS;
        // 3-bit field selecting the tonemapping operator. The TONEMAP_METHOD_* values below
        // are mutually exclusive values of this field, not independent flags.
        const TONEMAP_METHOD_RESERVED_BITS      = Self::TONEMAP_METHOD_MASK_BITS << Self::TONEMAP_METHOD_SHIFT_BITS;
        const TONEMAP_METHOD_NONE               = 0 << Self::TONEMAP_METHOD_SHIFT_BITS;
        const TONEMAP_METHOD_REINHARD           = 1 << Self::TONEMAP_METHOD_SHIFT_BITS;
        const TONEMAP_METHOD_REINHARD_LUMINANCE = 2 << Self::TONEMAP_METHOD_SHIFT_BITS;
        const TONEMAP_METHOD_ACES_FITTED        = 3 << Self::TONEMAP_METHOD_SHIFT_BITS;
        const TONEMAP_METHOD_AGX                = 4 << Self::TONEMAP_METHOD_SHIFT_BITS;
        const TONEMAP_METHOD_SOMEWHAT_BORING_DISPLAY_TRANSFORM = 5 << Self::TONEMAP_METHOD_SHIFT_BITS;
        const TONEMAP_METHOD_TONY_MC_MAPFACE    = 6 << Self::TONEMAP_METHOD_SHIFT_BITS;
        const TONEMAP_METHOD_BLENDER_FILMIC     = 7 << Self::TONEMAP_METHOD_SHIFT_BITS;
    }
}
109
110impl SpritePipelineKey {
111    const MSAA_MASK_BITS: u32 = 0b111;
112    const MSAA_SHIFT_BITS: u32 = 32 - Self::MSAA_MASK_BITS.count_ones();
113    const TONEMAP_METHOD_MASK_BITS: u32 = 0b111;
114    const TONEMAP_METHOD_SHIFT_BITS: u32 =
115        Self::MSAA_SHIFT_BITS - Self::TONEMAP_METHOD_MASK_BITS.count_ones();
116
117    #[inline]
118    pub const fn from_msaa_samples(msaa_samples: u32) -> Self {
119        let msaa_bits =
120            (msaa_samples.trailing_zeros() & Self::MSAA_MASK_BITS) << Self::MSAA_SHIFT_BITS;
121        Self::from_bits_retain(msaa_bits)
122    }
123
124    #[inline]
125    pub const fn msaa_samples(&self) -> u32 {
126        1 << ((self.bits() >> Self::MSAA_SHIFT_BITS) & Self::MSAA_MASK_BITS)
127    }
128
129    #[inline]
130    pub const fn from_hdr(hdr: bool) -> Self {
131        if hdr {
132            SpritePipelineKey::HDR
133        } else {
134            SpritePipelineKey::NONE
135        }
136    }
137}
138
139impl SpecializedRenderPipeline for SpritePipeline {
140    type Key = SpritePipelineKey;
141
142    fn specialize(&self, key: Self::Key) -> RenderPipelineDescriptor {
143        let mut shader_defs = Vec::new();
144        if key.contains(SpritePipelineKey::TONEMAP_IN_SHADER) {
145            shader_defs.push("TONEMAP_IN_SHADER".into());
146            shader_defs.push(ShaderDefVal::UInt(
147                "TONEMAPPING_LUT_TEXTURE_BINDING_INDEX".into(),
148                1,
149            ));
150            shader_defs.push(ShaderDefVal::UInt(
151                "TONEMAPPING_LUT_SAMPLER_BINDING_INDEX".into(),
152                2,
153            ));
154
155            let method = key.intersection(SpritePipelineKey::TONEMAP_METHOD_RESERVED_BITS);
156
157            if method == SpritePipelineKey::TONEMAP_METHOD_NONE {
158                shader_defs.push("TONEMAP_METHOD_NONE".into());
159            } else if method == SpritePipelineKey::TONEMAP_METHOD_REINHARD {
160                shader_defs.push("TONEMAP_METHOD_REINHARD".into());
161            } else if method == SpritePipelineKey::TONEMAP_METHOD_REINHARD_LUMINANCE {
162                shader_defs.push("TONEMAP_METHOD_REINHARD_LUMINANCE".into());
163            } else if method == SpritePipelineKey::TONEMAP_METHOD_ACES_FITTED {
164                shader_defs.push("TONEMAP_METHOD_ACES_FITTED".into());
165            } else if method == SpritePipelineKey::TONEMAP_METHOD_AGX {
166                shader_defs.push("TONEMAP_METHOD_AGX".into());
167            } else if method == SpritePipelineKey::TONEMAP_METHOD_SOMEWHAT_BORING_DISPLAY_TRANSFORM
168            {
169                shader_defs.push("TONEMAP_METHOD_SOMEWHAT_BORING_DISPLAY_TRANSFORM".into());
170            } else if method == SpritePipelineKey::TONEMAP_METHOD_BLENDER_FILMIC {
171                shader_defs.push("TONEMAP_METHOD_BLENDER_FILMIC".into());
172            } else if method == SpritePipelineKey::TONEMAP_METHOD_TONY_MC_MAPFACE {
173                shader_defs.push("TONEMAP_METHOD_TONY_MC_MAPFACE".into());
174            }
175
176            // Debanding is tied to tonemapping in the shader, cannot run without it.
177            if key.contains(SpritePipelineKey::DEBAND_DITHER) {
178                shader_defs.push("DEBAND_DITHER".into());
179            }
180        }
181
182        let format = match key.contains(SpritePipelineKey::HDR) {
183            true => ViewTarget::TEXTURE_FORMAT_HDR,
184            false => TextureFormat::bevy_default(),
185        };
186
187        let instance_rate_vertex_buffer_layout = VertexBufferLayout {
188            array_stride: 80,
189            step_mode: VertexStepMode::Instance,
190            attributes: vec![
191                // @location(0) i_model_transpose_col0: vec4<f32>,
192                VertexAttribute {
193                    format: VertexFormat::Float32x4,
194                    offset: 0,
195                    shader_location: 0,
196                },
197                // @location(1) i_model_transpose_col1: vec4<f32>,
198                VertexAttribute {
199                    format: VertexFormat::Float32x4,
200                    offset: 16,
201                    shader_location: 1,
202                },
203                // @location(2) i_model_transpose_col2: vec4<f32>,
204                VertexAttribute {
205                    format: VertexFormat::Float32x4,
206                    offset: 32,
207                    shader_location: 2,
208                },
209                // @location(3) i_color: vec4<f32>,
210                VertexAttribute {
211                    format: VertexFormat::Float32x4,
212                    offset: 48,
213                    shader_location: 3,
214                },
215                // @location(4) i_uv_offset_scale: vec4<f32>,
216                VertexAttribute {
217                    format: VertexFormat::Float32x4,
218                    offset: 64,
219                    shader_location: 4,
220                },
221            ],
222        };
223
224        RenderPipelineDescriptor {
225            vertex: VertexState {
226                shader: self.shader.clone(),
227                shader_defs: shader_defs.clone(),
228                buffers: vec![instance_rate_vertex_buffer_layout],
229                ..default()
230            },
231            fragment: Some(FragmentState {
232                shader: self.shader.clone(),
233                shader_defs,
234                targets: vec![Some(ColorTargetState {
235                    format,
236                    blend: Some(BlendState::ALPHA_BLENDING),
237                    write_mask: ColorWrites::ALL,
238                })],
239                ..default()
240            }),
241            layout: vec![self.view_layout.clone(), self.material_layout.clone()],
242            // Sprites are always alpha blended so they never need to write to depth.
243            // They just need to read it in case an opaque mesh2d
244            // that wrote to depth is present.
245            depth_stencil: Some(DepthStencilState {
246                format: CORE_2D_DEPTH_FORMAT,
247                depth_write_enabled: false,
248                depth_compare: CompareFunction::GreaterEqual,
249                stencil: StencilState {
250                    front: StencilFaceState::IGNORE,
251                    back: StencilFaceState::IGNORE,
252                    read_mask: 0,
253                    write_mask: 0,
254                },
255                bias: DepthBiasState {
256                    constant: 0,
257                    slope_scale: 0.0,
258                    clamp: 0.0,
259                },
260            }),
261            multisample: MultisampleState {
262                count: key.msaa_samples(),
263                mask: !0,
264                alpha_to_coverage_enabled: false,
265            },
266            label: Some("sprite_pipeline".into()),
267            ..default()
268        }
269    }
270}
271
272pub struct ExtractedSlice {
273    pub offset: Vec2,
274    pub rect: Rect,
275    pub size: Vec2,
276}
277
278pub struct ExtractedSprite {
279    pub main_entity: Entity,
280    pub render_entity: Entity,
281    pub transform: GlobalTransform,
282    pub color: LinearRgba,
283    /// Change the on-screen size of the sprite
284    /// Asset ID of the [`Image`] of this sprite
285    /// PERF: storing an `AssetId` instead of `Handle<Image>` enables some optimizations (`ExtractedSprite` becomes `Copy` and doesn't need to be dropped)
286    pub image_handle_id: AssetId<Image>,
287    pub flip_x: bool,
288    pub flip_y: bool,
289    pub kind: ExtractedSpriteKind,
290}
291
292pub enum ExtractedSpriteKind {
293    /// A single sprite with custom sizing and scaling options
294    Single {
295        anchor: Vec2,
296        rect: Option<Rect>,
297        scaling_mode: Option<SpriteScalingMode>,
298        custom_size: Option<Vec2>,
299    },
300    /// Indexes into the list of [`ExtractedSlice`]s stored in the [`ExtractedSlices`] resource
301    /// Used for elements composed from multiple sprites such as text or nine-patched borders
302    Slices { indices: Range<usize> },
303}
304
305#[derive(Resource, Default)]
306pub struct ExtractedSprites {
307    pub sprites: Vec<ExtractedSprite>,
308}
309
310#[derive(Resource, Default)]
311pub struct ExtractedSlices {
312    pub slices: Vec<ExtractedSlice>,
313}
314
315#[derive(Resource, Default)]
316pub struct SpriteAssetEvents {
317    pub images: Vec<AssetEvent<Image>>,
318}
319
320pub fn extract_sprite_events(
321    mut events: ResMut<SpriteAssetEvents>,
322    mut image_events: Extract<MessageReader<AssetEvent<Image>>>,
323) {
324    let SpriteAssetEvents { ref mut images } = *events;
325    images.clear();
326
327    for event in image_events.read() {
328        images.push(*event);
329    }
330}
331
332pub fn extract_sprites(
333    mut extracted_sprites: ResMut<ExtractedSprites>,
334    mut extracted_slices: ResMut<ExtractedSlices>,
335    texture_atlases: Extract<Res<Assets<TextureAtlasLayout>>>,
336    sprite_query: Extract<
337        Query<(
338            Entity,
339            RenderEntity,
340            &ViewVisibility,
341            &Sprite,
342            &GlobalTransform,
343            &Anchor,
344            Option<&ComputedTextureSlices>,
345        )>,
346    >,
347) {
348    extracted_sprites.sprites.clear();
349    extracted_slices.slices.clear();
350    for (main_entity, render_entity, view_visibility, sprite, transform, anchor, slices) in
351        sprite_query.iter()
352    {
353        if !view_visibility.get() {
354            continue;
355        }
356
357        if let Some(slices) = slices {
358            let start = extracted_slices.slices.len();
359            extracted_slices
360                .slices
361                .extend(slices.extract_slices(sprite, anchor.as_vec()));
362            let end = extracted_slices.slices.len();
363            extracted_sprites.sprites.push(ExtractedSprite {
364                main_entity,
365                render_entity,
366                color: sprite.color.into(),
367                transform: *transform,
368                flip_x: sprite.flip_x,
369                flip_y: sprite.flip_y,
370                image_handle_id: sprite.image.id(),
371                kind: ExtractedSpriteKind::Slices {
372                    indices: start..end,
373                },
374            });
375        } else {
376            let atlas_rect = sprite
377                .texture_atlas
378                .as_ref()
379                .and_then(|s| s.texture_rect(&texture_atlases).map(|r| r.as_rect()));
380            let rect = match (atlas_rect, sprite.rect) {
381                (None, None) => None,
382                (None, Some(sprite_rect)) => Some(sprite_rect),
383                (Some(atlas_rect), None) => Some(atlas_rect),
384                (Some(atlas_rect), Some(mut sprite_rect)) => {
385                    sprite_rect.min += atlas_rect.min;
386                    sprite_rect.max += atlas_rect.min;
387                    Some(sprite_rect)
388                }
389            };
390
391            // PERF: we don't check in this function that the `Image` asset is ready, since it should be in most cases and hashing the handle is expensive
392            extracted_sprites.sprites.push(ExtractedSprite {
393                main_entity,
394                render_entity,
395                color: sprite.color.into(),
396                transform: *transform,
397                flip_x: sprite.flip_x,
398                flip_y: sprite.flip_y,
399                image_handle_id: sprite.image.id(),
400                kind: ExtractedSpriteKind::Single {
401                    anchor: anchor.as_vec(),
402                    rect,
403                    scaling_mode: sprite.image_mode.scale(),
404                    // Pass the custom size
405                    custom_size: sprite.custom_size,
406                },
407            });
408        }
409    }
410}
411
412#[repr(C)]
413#[derive(Copy, Clone, Pod, Zeroable)]
414struct SpriteInstance {
415    // Affine 4x3 transposed to 3x4
416    pub i_model_transpose: [Vec4; 3],
417    pub i_color: [f32; 4],
418    pub i_uv_offset_scale: [f32; 4],
419}
420
421impl SpriteInstance {
422    #[inline]
423    fn from(transform: &Affine3A, color: &LinearRgba, uv_offset_scale: &Vec4) -> Self {
424        let transpose_model_3x3 = transform.matrix3.transpose();
425        Self {
426            i_model_transpose: [
427                transpose_model_3x3.x_axis.extend(transform.translation.x),
428                transpose_model_3x3.y_axis.extend(transform.translation.y),
429                transpose_model_3x3.z_axis.extend(transform.translation.z),
430            ],
431            i_color: color.to_f32_array(),
432            i_uv_offset_scale: uv_offset_scale.to_array(),
433        }
434    }
435}
436
437#[derive(Resource)]
438pub struct SpriteMeta {
439    sprite_index_buffer: RawBufferVec<u32>,
440    sprite_instance_buffer: RawBufferVec<SpriteInstance>,
441}
442
443impl Default for SpriteMeta {
444    fn default() -> Self {
445        Self {
446            sprite_index_buffer: RawBufferVec::<u32>::new(BufferUsages::INDEX),
447            sprite_instance_buffer: RawBufferVec::<SpriteInstance>::new(BufferUsages::VERTEX),
448        }
449    }
450}
451
452#[derive(Component)]
453pub struct SpriteViewBindGroup {
454    pub value: BindGroup,
455}
456
457#[derive(Resource, Deref, DerefMut, Default)]
458pub struct SpriteBatches(HashMap<(RetainedViewEntity, Entity), SpriteBatch>);
459
460#[derive(PartialEq, Eq, Clone, Debug)]
461pub struct SpriteBatch {
462    image_handle_id: AssetId<Image>,
463    range: Range<u32>,
464}
465
466#[derive(Resource, Default)]
467pub struct ImageBindGroups {
468    values: HashMap<AssetId<Image>, BindGroup>,
469}
470
471pub fn queue_sprites(
472    mut view_entities: Local<FixedBitSet>,
473    draw_functions: Res<DrawFunctions<Transparent2d>>,
474    sprite_pipeline: Res<SpritePipeline>,
475    mut pipelines: ResMut<SpecializedRenderPipelines<SpritePipeline>>,
476    pipeline_cache: Res<PipelineCache>,
477    extracted_sprites: Res<ExtractedSprites>,
478    mut transparent_render_phases: ResMut<ViewSortedRenderPhases<Transparent2d>>,
479    mut views: Query<(
480        &RenderVisibleEntities,
481        &ExtractedView,
482        &Msaa,
483        Option<&Tonemapping>,
484        Option<&DebandDither>,
485    )>,
486) {
487    let draw_sprite_function = draw_functions.read().id::<DrawSprite>();
488
489    for (visible_entities, view, msaa, tonemapping, dither) in &mut views {
490        let Some(transparent_phase) = transparent_render_phases.get_mut(&view.retained_view_entity)
491        else {
492            continue;
493        };
494
495        let msaa_key = SpritePipelineKey::from_msaa_samples(msaa.samples());
496        let mut view_key = SpritePipelineKey::from_hdr(view.hdr) | msaa_key;
497
498        if !view.hdr {
499            if let Some(tonemapping) = tonemapping {
500                view_key |= SpritePipelineKey::TONEMAP_IN_SHADER;
501                view_key |= match tonemapping {
502                    Tonemapping::None => SpritePipelineKey::TONEMAP_METHOD_NONE,
503                    Tonemapping::Reinhard => SpritePipelineKey::TONEMAP_METHOD_REINHARD,
504                    Tonemapping::ReinhardLuminance => {
505                        SpritePipelineKey::TONEMAP_METHOD_REINHARD_LUMINANCE
506                    }
507                    Tonemapping::AcesFitted => SpritePipelineKey::TONEMAP_METHOD_ACES_FITTED,
508                    Tonemapping::AgX => SpritePipelineKey::TONEMAP_METHOD_AGX,
509                    Tonemapping::SomewhatBoringDisplayTransform => {
510                        SpritePipelineKey::TONEMAP_METHOD_SOMEWHAT_BORING_DISPLAY_TRANSFORM
511                    }
512                    Tonemapping::TonyMcMapface => SpritePipelineKey::TONEMAP_METHOD_TONY_MC_MAPFACE,
513                    Tonemapping::BlenderFilmic => SpritePipelineKey::TONEMAP_METHOD_BLENDER_FILMIC,
514                };
515            }
516            if let Some(DebandDither::Enabled) = dither {
517                view_key |= SpritePipelineKey::DEBAND_DITHER;
518            }
519        }
520
521        let pipeline = pipelines.specialize(&pipeline_cache, &sprite_pipeline, view_key);
522
523        view_entities.clear();
524        view_entities.extend(
525            visible_entities
526                .iter::<Sprite>()
527                .map(|(_, e)| e.index_u32() as usize),
528        );
529
530        transparent_phase
531            .items
532            .reserve(extracted_sprites.sprites.len());
533
534        for (index, extracted_sprite) in extracted_sprites.sprites.iter().enumerate() {
535            let view_index = extracted_sprite.main_entity.index_u32();
536
537            if !view_entities.contains(view_index as usize) {
538                continue;
539            }
540
541            // These items will be sorted by depth with other phase items
542            let sort_key = FloatOrd(extracted_sprite.transform.translation().z);
543
544            // Add the item to the render phase
545            transparent_phase.add(Transparent2d {
546                draw_function: draw_sprite_function,
547                pipeline,
548                entity: (
549                    extracted_sprite.render_entity,
550                    extracted_sprite.main_entity.into(),
551                ),
552                sort_key,
553                // `batch_range` is calculated in `prepare_sprite_image_bind_groups`
554                batch_range: 0..0,
555                extra_index: PhaseItemExtraIndex::None,
556                extracted_index: index,
557                indexed: true,
558            });
559        }
560    }
561}
562
563pub fn prepare_sprite_view_bind_groups(
564    mut commands: Commands,
565    render_device: Res<RenderDevice>,
566    pipeline_cache: Res<PipelineCache>,
567    sprite_pipeline: Res<SpritePipeline>,
568    view_uniforms: Res<ViewUniforms>,
569    views: Query<(Entity, &Tonemapping), With<ExtractedView>>,
570    tonemapping_luts: Res<TonemappingLuts>,
571    images: Res<RenderAssets<GpuImage>>,
572    fallback_image: Res<FallbackImage>,
573) {
574    let Some(view_binding) = view_uniforms.uniforms.binding() else {
575        return;
576    };
577
578    for (entity, tonemapping) in &views {
579        let lut_bindings =
580            get_lut_bindings(&images, &tonemapping_luts, tonemapping, &fallback_image);
581        let view_bind_group = render_device.create_bind_group(
582            "mesh2d_view_bind_group",
583            &pipeline_cache.get_bind_group_layout(&sprite_pipeline.view_layout),
584            &BindGroupEntries::sequential((view_binding.clone(), lut_bindings.0, lut_bindings.1)),
585        );
586
587        commands.entity(entity).insert(SpriteViewBindGroup {
588            value: view_bind_group,
589        });
590    }
591}
592
593pub fn prepare_sprite_image_bind_groups(
594    render_device: Res<RenderDevice>,
595    render_queue: Res<RenderQueue>,
596    pipeline_cache: Res<PipelineCache>,
597    mut sprite_meta: ResMut<SpriteMeta>,
598    sprite_pipeline: Res<SpritePipeline>,
599    mut image_bind_groups: ResMut<ImageBindGroups>,
600    gpu_images: Res<RenderAssets<GpuImage>>,
601    extracted_sprites: Res<ExtractedSprites>,
602    extracted_slices: Res<ExtractedSlices>,
603    mut phases: ResMut<ViewSortedRenderPhases<Transparent2d>>,
604    events: Res<SpriteAssetEvents>,
605    mut batches: ResMut<SpriteBatches>,
606) {
607    // If an image has changed, the GpuImage has (probably) changed
608    for event in &events.images {
609        match event {
610            AssetEvent::Added { .. } |
611            // Images don't have dependencies
612            AssetEvent::LoadedWithDependencies { .. } => {}
613            AssetEvent::Unused { id } | AssetEvent::Modified { id } | AssetEvent::Removed { id } => {
614                image_bind_groups.values.remove(id);
615            }
616        };
617    }
618
619    batches.clear();
620
621    // Clear the sprite instances
622    sprite_meta.sprite_instance_buffer.clear();
623
624    // Index buffer indices
625    let mut index = 0;
626
627    let image_bind_groups = &mut *image_bind_groups;
628
629    for (retained_view, transparent_phase) in phases.iter_mut() {
630        let mut current_batch = None;
631        let mut batch_item_index = 0;
632        let mut batch_image_size = Vec2::ZERO;
633        let mut batch_image_handle = AssetId::invalid();
634
635        // Iterate through the phase items and detect when successive sprites that can be batched.
636        // Spawn an entity with a `SpriteBatch` component for each possible batch.
637        // Compatible items share the same entity.
638        for item_index in 0..transparent_phase.items.len() {
639            let item = &transparent_phase.items[item_index];
640
641            let Some(extracted_sprite) = extracted_sprites
642                .sprites
643                .get(item.extracted_index)
644                .filter(|extracted_sprite| extracted_sprite.render_entity == item.entity())
645            else {
646                // If there is a phase item that is not a sprite, then we must start a new
647                // batch to draw the other phase item(s) and to respect draw order. This can be
648                // done by invalidating the batch_image_handle
649                batch_image_handle = AssetId::invalid();
650                continue;
651            };
652
653            if batch_image_handle != extracted_sprite.image_handle_id {
654                let Some(gpu_image) = gpu_images.get(extracted_sprite.image_handle_id) else {
655                    continue;
656                };
657
658                batch_image_size = gpu_image.size_2d().as_vec2();
659                batch_image_handle = extracted_sprite.image_handle_id;
660                image_bind_groups
661                    .values
662                    .entry(batch_image_handle)
663                    .or_insert_with(|| {
664                        render_device.create_bind_group(
665                            "sprite_material_bind_group",
666                            &pipeline_cache.get_bind_group_layout(&sprite_pipeline.material_layout),
667                            &BindGroupEntries::sequential((
668                                &gpu_image.texture_view,
669                                &gpu_image.sampler,
670                            )),
671                        )
672                    });
673
674                batch_item_index = item_index;
675                current_batch = Some(batches.entry((*retained_view, item.entity())).insert(
676                    SpriteBatch {
677                        image_handle_id: batch_image_handle,
678                        range: index..index,
679                    },
680                ));
681            }
682            match extracted_sprite.kind {
683                ExtractedSpriteKind::Single {
684                    anchor,
685                    rect,
686                    scaling_mode,
687                    custom_size,
688                } => {
689                    // By default, the size of the quad is the size of the texture
690                    let mut quad_size = batch_image_size;
691                    let mut texture_size = batch_image_size;
692
693                    // Calculate vertex data for this item
694                    // If a rect is specified, adjust UVs and the size of the quad
695                    let mut uv_offset_scale = if let Some(rect) = rect {
696                        let rect_size = rect.size();
697                        quad_size = rect_size;
698                        // Update texture size to the rect size
699                        // It will help scale properly only portion of the image
700                        texture_size = rect_size;
701                        Vec4::new(
702                            rect.min.x / batch_image_size.x,
703                            rect.max.y / batch_image_size.y,
704                            rect_size.x / batch_image_size.x,
705                            -rect_size.y / batch_image_size.y,
706                        )
707                    } else {
708                        Vec4::new(0.0, 1.0, 1.0, -1.0)
709                    };
710
711                    if extracted_sprite.flip_x {
712                        uv_offset_scale.x += uv_offset_scale.z;
713                        uv_offset_scale.z *= -1.0;
714                    }
715                    if extracted_sprite.flip_y {
716                        uv_offset_scale.y += uv_offset_scale.w;
717                        uv_offset_scale.w *= -1.0;
718                    }
719
720                    // Override the size if a custom one is specified
721                    quad_size = custom_size.unwrap_or(quad_size);
722
723                    // Used for translation of the quad if `TextureScale::Fit...` is specified.
724                    let mut quad_translation = Vec2::ZERO;
725
726                    // Scales the texture based on the `texture_scale` field.
727                    if let Some(scaling_mode) = scaling_mode {
728                        apply_scaling(
729                            scaling_mode,
730                            texture_size,
731                            &mut quad_size,
732                            &mut quad_translation,
733                            &mut uv_offset_scale,
734                        );
735                    }
736
737                    let transform = extracted_sprite.transform.affine()
738                        * Affine3A::from_scale_rotation_translation(
739                            quad_size.extend(1.0),
740                            Quat::IDENTITY,
741                            ((quad_size + quad_translation) * (-anchor - Vec2::splat(0.5)))
742                                .extend(0.0),
743                        );
744
745                    // Store the vertex data and add the item to the render phase
746                    sprite_meta
747                        .sprite_instance_buffer
748                        .push(SpriteInstance::from(
749                            &transform,
750                            &extracted_sprite.color,
751                            &uv_offset_scale,
752                        ));
753
754                    current_batch.as_mut().unwrap().get_mut().range.end += 1;
755                    index += 1;
756                }
757                ExtractedSpriteKind::Slices { ref indices } => {
758                    for i in indices.clone() {
759                        let slice = &extracted_slices.slices[i];
760                        let rect = slice.rect;
761                        let rect_size = rect.size();
762
763                        // Calculate vertex data for this item
764                        let mut uv_offset_scale: Vec4;
765
766                        // If a rect is specified, adjust UVs and the size of the quad
767                        uv_offset_scale = Vec4::new(
768                            rect.min.x / batch_image_size.x,
769                            rect.max.y / batch_image_size.y,
770                            rect_size.x / batch_image_size.x,
771                            -rect_size.y / batch_image_size.y,
772                        );
773
774                        if extracted_sprite.flip_x {
775                            uv_offset_scale.x += uv_offset_scale.z;
776                            uv_offset_scale.z *= -1.0;
777                        }
778                        if extracted_sprite.flip_y {
779                            uv_offset_scale.y += uv_offset_scale.w;
780                            uv_offset_scale.w *= -1.0;
781                        }
782
783                        let transform = extracted_sprite.transform.affine()
784                            * Affine3A::from_scale_rotation_translation(
785                                slice.size.extend(1.0),
786                                Quat::IDENTITY,
787                                (slice.size * -Vec2::splat(0.5) + slice.offset).extend(0.0),
788                            );
789
790                        // Store the vertex data and add the item to the render phase
791                        sprite_meta
792                            .sprite_instance_buffer
793                            .push(SpriteInstance::from(
794                                &transform,
795                                &extracted_sprite.color,
796                                &uv_offset_scale,
797                            ));
798
799                        current_batch.as_mut().unwrap().get_mut().range.end += 1;
800                        index += 1;
801                    }
802                }
803            }
804            transparent_phase.items[batch_item_index]
805                .batch_range_mut()
806                .end += 1;
807        }
808        sprite_meta
809            .sprite_instance_buffer
810            .write_buffer(&render_device, &render_queue);
811
812        if sprite_meta.sprite_index_buffer.len() != 6 {
813            sprite_meta.sprite_index_buffer.clear();
814
815            // NOTE: This code is creating 6 indices pointing to 4 vertices.
816            // The vertices form the corners of a quad based on their two least significant bits.
817            // 10   11
818            //
819            // 00   01
820            // The sprite shader can then use the two least significant bits as the vertex index.
821            // The rest of the properties to transform the vertex positions and UVs (which are
822            // implicit) are baked into the instance transform, and UV offset and scale.
823            // See bevy_sprite_render/src/render/sprite.wgsl for the details.
824            sprite_meta.sprite_index_buffer.push(2);
825            sprite_meta.sprite_index_buffer.push(0);
826            sprite_meta.sprite_index_buffer.push(1);
827            sprite_meta.sprite_index_buffer.push(1);
828            sprite_meta.sprite_index_buffer.push(3);
829            sprite_meta.sprite_index_buffer.push(2);
830
831            sprite_meta
832                .sprite_index_buffer
833                .write_buffer(&render_device, &render_queue);
834        }
835    }
836}
/// [`RenderCommand`] for sprite rendering.
///
/// Composed of, in order:
/// 1. [`SetItemPipeline`]: binds the phase item's render pipeline.
/// 2. [`SetSpriteViewBindGroup`] at bind group index 0: the per-view bind group, with the view
///    uniform offset as its dynamic offset.
/// 3. [`SetSpriteTextureBindGroup`] at bind group index 1: the texture bind group of the item's
///    sprite batch.
/// 4. [`DrawSpriteBatch`]: the instanced, indexed draw call for the batch.
pub type DrawSprite = (
    SetItemPipeline,
    SetSpriteViewBindGroup<0>,
    SetSpriteTextureBindGroup<1>,
    DrawSpriteBatch,
);
844
845pub struct SetSpriteViewBindGroup<const I: usize>;
846impl<P: PhaseItem, const I: usize> RenderCommand<P> for SetSpriteViewBindGroup<I> {
847    type Param = ();
848    type ViewQuery = (Read<ViewUniformOffset>, Read<SpriteViewBindGroup>);
849    type ItemQuery = ();
850
851    fn render<'w>(
852        _item: &P,
853        (view_uniform, sprite_view_bind_group): ROQueryItem<'w, '_, Self::ViewQuery>,
854        _entity: Option<()>,
855        _param: SystemParamItem<'w, '_, Self::Param>,
856        pass: &mut TrackedRenderPass<'w>,
857    ) -> RenderCommandResult {
858        pass.set_bind_group(I, &sprite_view_bind_group.value, &[view_uniform.offset]);
859        RenderCommandResult::Success
860    }
861}
862pub struct SetSpriteTextureBindGroup<const I: usize>;
863impl<P: PhaseItem, const I: usize> RenderCommand<P> for SetSpriteTextureBindGroup<I> {
864    type Param = (SRes<ImageBindGroups>, SRes<SpriteBatches>);
865    type ViewQuery = Read<ExtractedView>;
866    type ItemQuery = ();
867
868    fn render<'w>(
869        item: &P,
870        view: ROQueryItem<'w, '_, Self::ViewQuery>,
871        _entity: Option<()>,
872        (image_bind_groups, batches): SystemParamItem<'w, '_, Self::Param>,
873        pass: &mut TrackedRenderPass<'w>,
874    ) -> RenderCommandResult {
875        let image_bind_groups = image_bind_groups.into_inner();
876        let Some(batch) = batches.get(&(view.retained_view_entity, item.entity())) else {
877            return RenderCommandResult::Skip;
878        };
879
880        pass.set_bind_group(
881            I,
882            image_bind_groups
883                .values
884                .get(&batch.image_handle_id)
885                .unwrap(),
886            &[],
887        );
888        RenderCommandResult::Success
889    }
890}
891
892pub struct DrawSpriteBatch;
893impl<P: PhaseItem> RenderCommand<P> for DrawSpriteBatch {
894    type Param = (SRes<SpriteMeta>, SRes<SpriteBatches>);
895    type ViewQuery = Read<ExtractedView>;
896    type ItemQuery = ();
897
898    fn render<'w>(
899        item: &P,
900        view: ROQueryItem<'w, '_, Self::ViewQuery>,
901        _entity: Option<()>,
902        (sprite_meta, batches): SystemParamItem<'w, '_, Self::Param>,
903        pass: &mut TrackedRenderPass<'w>,
904    ) -> RenderCommandResult {
905        let sprite_meta = sprite_meta.into_inner();
906        let Some(batch) = batches.get(&(view.retained_view_entity, item.entity())) else {
907            return RenderCommandResult::Skip;
908        };
909
910        pass.set_index_buffer(
911            sprite_meta.sprite_index_buffer.buffer().unwrap().slice(..),
912            IndexFormat::Uint32,
913        );
914        pass.set_vertex_buffer(
915            0,
916            sprite_meta
917                .sprite_instance_buffer
918                .buffer()
919                .unwrap()
920                .slice(..),
921        );
922        pass.draw_indexed(0..6, 0, batch.range.clone());
923        RenderCommandResult::Success
924    }
925}
926
927/// Scales a texture to fit within a given quad size with keeping the aspect ratio.
928fn apply_scaling(
929    scaling_mode: SpriteScalingMode,
930    texture_size: Vec2,
931    quad_size: &mut Vec2,
932    quad_translation: &mut Vec2,
933    uv_offset_scale: &mut Vec4,
934) {
935    let quad_ratio = quad_size.x / quad_size.y;
936    let texture_ratio = texture_size.x / texture_size.y;
937    let tex_quad_scale = texture_ratio / quad_ratio;
938    let quad_tex_scale = quad_ratio / texture_ratio;
939
940    match scaling_mode {
941        SpriteScalingMode::FillCenter => {
942            if quad_ratio > texture_ratio {
943                // offset texture to center by y coordinate
944                uv_offset_scale.y += (uv_offset_scale.w - uv_offset_scale.w * tex_quad_scale) * 0.5;
945                // sum up scales
946                uv_offset_scale.w *= tex_quad_scale;
947            } else {
948                // offset texture to center by x coordinate
949                uv_offset_scale.x += (uv_offset_scale.z - uv_offset_scale.z * quad_tex_scale) * 0.5;
950                uv_offset_scale.z *= quad_tex_scale;
951            };
952        }
953        SpriteScalingMode::FillStart => {
954            if quad_ratio > texture_ratio {
955                uv_offset_scale.y += uv_offset_scale.w - uv_offset_scale.w * tex_quad_scale;
956                uv_offset_scale.w *= tex_quad_scale;
957            } else {
958                uv_offset_scale.z *= quad_tex_scale;
959            }
960        }
961        SpriteScalingMode::FillEnd => {
962            if quad_ratio > texture_ratio {
963                uv_offset_scale.w *= tex_quad_scale;
964            } else {
965                uv_offset_scale.x += uv_offset_scale.z - uv_offset_scale.z * quad_tex_scale;
966                uv_offset_scale.z *= quad_tex_scale;
967            }
968        }
969        SpriteScalingMode::FitCenter => {
970            if texture_ratio > quad_ratio {
971                // Scale based on width
972                quad_size.y *= quad_tex_scale;
973            } else {
974                // Scale based on height
975                quad_size.x *= tex_quad_scale;
976            }
977        }
978        SpriteScalingMode::FitStart => {
979            if texture_ratio > quad_ratio {
980                // The quad is scaled to match the image ratio, and the quad translation is adjusted
981                // to start of the quad within the original quad size.
982                let scale = Vec2::new(1.0, quad_tex_scale);
983                let new_quad = *quad_size * scale;
984                let offset = *quad_size - new_quad;
985                *quad_translation = Vec2::new(0.0, -offset.y);
986                *quad_size = new_quad;
987            } else {
988                let scale = Vec2::new(tex_quad_scale, 1.0);
989                let new_quad = *quad_size * scale;
990                let offset = *quad_size - new_quad;
991                *quad_translation = Vec2::new(offset.x, 0.0);
992                *quad_size = new_quad;
993            }
994        }
995        SpriteScalingMode::FitEnd => {
996            if texture_ratio > quad_ratio {
997                let scale = Vec2::new(1.0, quad_tex_scale);
998                let new_quad = *quad_size * scale;
999                let offset = *quad_size - new_quad;
1000                *quad_translation = Vec2::new(0.0, offset.y);
1001                *quad_size = new_quad;
1002            } else {
1003                let scale = Vec2::new(tex_quad_scale, 1.0);
1004                let new_quad = *quad_size * scale;
1005                let offset = *quad_size - new_quad;
1006                *quad_translation = Vec2::new(-offset.x, 0.0);
1007                *quad_size = new_quad;
1008            }
1009        }
1010    }
1011}