bevy_sprite_render/mesh2d/
material.rs

1use crate::{
2    init_mesh_2d_pipeline, DrawMesh2d, Mesh2d, Mesh2dPipeline, Mesh2dPipelineKey,
3    RenderMesh2dInstances, SetMesh2dBindGroup, SetMesh2dViewBindGroup, ViewKeyCache,
4    ViewSpecializationTicks,
5};
6use bevy_app::{App, Plugin, PostUpdate};
7use bevy_asset::prelude::AssetChanged;
8use bevy_asset::{AsAssetId, Asset, AssetApp, AssetEventSystems, AssetId, AssetServer, Handle};
9use bevy_camera::visibility::ViewVisibility;
10use bevy_core_pipeline::{
11    core_2d::{
12        AlphaMask2d, AlphaMask2dBinKey, BatchSetKey2d, Opaque2d, Opaque2dBinKey, Transparent2d,
13    },
14    tonemapping::Tonemapping,
15};
16use bevy_derive::{Deref, DerefMut};
17use bevy_ecs::component::Tick;
18use bevy_ecs::system::SystemChangeTick;
19use bevy_ecs::{
20    prelude::*,
21    system::{lifetimeless::SRes, SystemParamItem},
22};
23use bevy_math::FloatOrd;
24use bevy_mesh::MeshVertexBufferLayoutRef;
25use bevy_platform::collections::HashMap;
26use bevy_reflect::{prelude::ReflectDefault, Reflect};
27use bevy_render::{
28    camera::extract_cameras,
29    mesh::RenderMesh,
30    render_asset::{
31        prepare_assets, PrepareAssetError, RenderAsset, RenderAssetPlugin, RenderAssets,
32    },
33    render_phase::{
34        AddRenderCommand, BinnedRenderPhaseType, DrawFunctionId, DrawFunctions, InputUniformIndex,
35        PhaseItem, PhaseItemExtraIndex, RenderCommand, RenderCommandResult, SetItemPipeline,
36        TrackedRenderPass, ViewBinnedRenderPhases, ViewSortedRenderPhases,
37    },
38    render_resource::{
39        AsBindGroup, AsBindGroupError, BindGroup, BindGroupId, BindGroupLayout, BindingResources,
40        CachedRenderPipelineId, PipelineCache, RenderPipelineDescriptor, SpecializedMeshPipeline,
41        SpecializedMeshPipelineError, SpecializedMeshPipelines,
42    },
43    renderer::RenderDevice,
44    sync_world::{MainEntity, MainEntityHashMap},
45    view::{ExtractedView, RenderVisibleEntities},
46    Extract, ExtractSchedule, Render, RenderApp, RenderStartup, RenderSystems,
47};
48use bevy_shader::{Shader, ShaderDefVal, ShaderRef};
49use bevy_utils::Parallel;
50use core::{hash::Hash, marker::PhantomData};
51use derive_more::derive::From;
52use tracing::error;
53
/// Index of the bind group that holds a [`Material2d`]'s resources.
///
/// The `MATERIAL_BIND_GROUP` shader def is set to this value, the material's bind group layout is
/// the third entry (index 2) of the pipeline layout built by [`Material2dPipeline`], and
/// [`SetMaterial2dBindGroup`] is instantiated with this index in [`DrawMaterial2d`].
pub const MATERIAL_2D_BIND_GROUP_INDEX: usize = 2;
55
56/// Materials are used alongside [`Material2dPlugin`], [`Mesh2d`], and [`MeshMaterial2d`]
57/// to spawn entities that are rendered with a specific [`Material2d`] type. They serve as an easy to use high level
58/// way to render [`Mesh2d`] entities with custom shader logic.
59///
60/// Materials must implement [`AsBindGroup`] to define how data will be transferred to the GPU and bound in shaders.
61/// [`AsBindGroup`] can be derived, which makes generating bindings straightforward. See the [`AsBindGroup`] docs for details.
62///
63/// # Example
64///
65/// Here is a simple [`Material2d`] implementation. The [`AsBindGroup`] derive has many features. To see what else is available,
66/// check out the [`AsBindGroup`] documentation.
67///
68/// ```
69/// # use bevy_sprite_render::{Material2d, MeshMaterial2d};
70/// # use bevy_ecs::prelude::*;
71/// # use bevy_image::Image;
72/// # use bevy_reflect::TypePath;
73/// # use bevy_mesh::{Mesh, Mesh2d};
74/// # use bevy_render::render_resource::AsBindGroup;
75/// # use bevy_shader::ShaderRef;
76/// # use bevy_color::LinearRgba;
77/// # use bevy_color::palettes::basic::RED;
78/// # use bevy_asset::{Handle, AssetServer, Assets, Asset};
79/// # use bevy_math::primitives::Circle;
80/// #
81/// #[derive(AsBindGroup, Debug, Clone, Asset, TypePath)]
82/// pub struct CustomMaterial {
83///     // Uniform bindings must implement `ShaderType`, which will be used to convert the value to
84///     // its shader-compatible equivalent. Most core math types already implement `ShaderType`.
85///     #[uniform(0)]
86///     color: LinearRgba,
87///     // Images can be bound as textures in shaders. If the Image's sampler is also needed, just
88///     // add the sampler attribute with a different binding index.
89///     #[texture(1)]
90///     #[sampler(2)]
91///     color_texture: Handle<Image>,
92/// }
93///
94/// // All functions on `Material2d` have default impls. You only need to implement the
95/// // functions that are relevant for your material.
96/// impl Material2d for CustomMaterial {
97///     fn fragment_shader() -> ShaderRef {
98///         "shaders/custom_material.wgsl".into()
99///     }
100/// }
101///
102/// // Spawn an entity with a mesh using `CustomMaterial`.
103/// fn setup(
104///     mut commands: Commands,
105///     mut meshes: ResMut<Assets<Mesh>>,
106///     mut materials: ResMut<Assets<CustomMaterial>>,
107///     asset_server: Res<AssetServer>,
108/// ) {
109///     commands.spawn((
110///         Mesh2d(meshes.add(Circle::new(50.0))),
111///         MeshMaterial2d(materials.add(CustomMaterial {
112///             color: RED.into(),
113///             color_texture: asset_server.load("some_image.png"),
114///         })),
115///     ));
116/// }
117/// ```
118///
119/// In WGSL shaders, the material's binding would look like this:
120///
121/// ```wgsl
122/// struct CustomMaterial {
123///     color: vec4<f32>,
124/// }
125///
126/// @group(2) @binding(0) var<uniform> material: CustomMaterial;
127/// @group(2) @binding(1) var color_texture: texture_2d<f32>;
128/// @group(2) @binding(2) var color_sampler: sampler;
129/// ```
130pub trait Material2d: AsBindGroup + Asset + Clone + Sized {
131    /// Returns this material's vertex shader. If [`ShaderRef::Default`] is returned, the default mesh vertex shader
132    /// will be used.
133    fn vertex_shader() -> ShaderRef {
134        ShaderRef::Default
135    }
136
137    /// Returns this material's fragment shader. If [`ShaderRef::Default`] is returned, the default mesh fragment shader
138    /// will be used.
139    fn fragment_shader() -> ShaderRef {
140        ShaderRef::Default
141    }
142
143    /// Add a bias to the view depth of the mesh which can be used to force a specific render order.
144    #[inline]
145    fn depth_bias(&self) -> f32 {
146        0.0
147    }
148
149    fn alpha_mode(&self) -> AlphaMode2d {
150        AlphaMode2d::Opaque
151    }
152
153    /// Customizes the default [`RenderPipelineDescriptor`].
154    #[expect(
155        unused_variables,
156        reason = "The parameters here are intentionally unused by the default implementation; however, putting underscores here will result in the underscores being copied by rust-analyzer's tab completion."
157    )]
158    #[inline]
159    fn specialize(
160        descriptor: &mut RenderPipelineDescriptor,
161        layout: &MeshVertexBufferLayoutRef,
162        key: Material2dKey<Self>,
163    ) -> Result<(), SpecializedMeshPipelineError> {
164        Ok(())
165    }
166}
167
168/// A [material](Material2d) used for rendering a [`Mesh2d`].
169///
170/// See [`Material2d`] for general information about 2D materials and how to implement your own materials.
171///
172/// # Example
173///
174/// ```
175/// # use bevy_sprite_render::{ColorMaterial, MeshMaterial2d};
176/// # use bevy_ecs::prelude::*;
177/// # use bevy_mesh::{Mesh, Mesh2d};
178/// # use bevy_color::palettes::basic::RED;
179/// # use bevy_asset::Assets;
180/// # use bevy_math::primitives::Circle;
181/// #
182/// // Spawn an entity with a mesh using `ColorMaterial`.
183/// fn setup(
184///     mut commands: Commands,
185///     mut meshes: ResMut<Assets<Mesh>>,
186///     mut materials: ResMut<Assets<ColorMaterial>>,
187/// ) {
188///     commands.spawn((
189///         Mesh2d(meshes.add(Circle::new(50.0))),
190///         MeshMaterial2d(materials.add(ColorMaterial::from_color(RED))),
191///     ));
192/// }
193/// ```
194///
195/// [`MeshMaterial2d`]: crate::MeshMaterial2d
196#[derive(Component, Clone, Debug, Deref, DerefMut, Reflect, From)]
197#[reflect(Component, Default, Clone)]
198pub struct MeshMaterial2d<M: Material2d>(pub Handle<M>);
199
200impl<M: Material2d> Default for MeshMaterial2d<M> {
201    fn default() -> Self {
202        Self(Handle::default())
203    }
204}
205
206impl<M: Material2d> PartialEq for MeshMaterial2d<M> {
207    fn eq(&self, other: &Self) -> bool {
208        self.0 == other.0
209    }
210}
211
212impl<M: Material2d> Eq for MeshMaterial2d<M> {}
213
214impl<M: Material2d> From<MeshMaterial2d<M>> for AssetId<M> {
215    fn from(material: MeshMaterial2d<M>) -> Self {
216        material.id()
217    }
218}
219
220impl<M: Material2d> From<&MeshMaterial2d<M>> for AssetId<M> {
221    fn from(material: &MeshMaterial2d<M>) -> Self {
222        material.id()
223    }
224}
225
226impl<M: Material2d> AsAssetId for MeshMaterial2d<M> {
227    type Asset = M;
228
229    fn as_asset_id(&self) -> AssetId<Self::Asset> {
230        self.id()
231    }
232}
233
234/// Sets how a 2d material's base color alpha channel is used for transparency.
235/// Currently, this only works with [`Mesh2d`]. Sprites are always transparent.
236///
237/// This is very similar to [`AlphaMode`](bevy_render::alpha::AlphaMode) but this only applies to 2d meshes.
238/// We use a separate type because 2d doesn't support all the transparency modes that 3d does.
239#[derive(Debug, Default, Reflect, Copy, Clone, PartialEq)]
240#[reflect(Default, Debug, Clone)]
241pub enum AlphaMode2d {
242    /// Base color alpha values are overridden to be fully opaque (1.0).
243    #[default]
244    Opaque,
245    /// Reduce transparency to fully opaque or fully transparent
246    /// based on a threshold.
247    ///
248    /// Compares the base color alpha value to the specified threshold.
249    /// If the value is below the threshold,
250    /// considers the color to be fully transparent (alpha is set to 0.0).
251    /// If it is equal to or above the threshold,
252    /// considers the color to be fully opaque (alpha is set to 1.0).
253    Mask(f32),
254    /// The base color alpha value defines the opacity of the color.
255    /// Standard alpha-blending is used to blend the fragment's color
256    /// with the color behind it.
257    Blend,
258}
259
260/// Adds the necessary ECS resources and render logic to enable rendering entities using the given [`Material2d`]
261/// asset type (which includes [`Material2d`] types).
262pub struct Material2dPlugin<M: Material2d>(PhantomData<M>);
263
264impl<M: Material2d> Default for Material2dPlugin<M> {
265    fn default() -> Self {
266        Self(Default::default())
267    }
268}
269
270impl<M: Material2d> Plugin for Material2dPlugin<M>
271where
272    M::Data: PartialEq + Eq + Hash + Clone,
273{
274    fn build(&self, app: &mut App) {
275        app.init_asset::<M>()
276            .init_resource::<EntitiesNeedingSpecialization<M>>()
277            .register_type::<MeshMaterial2d<M>>()
278            .add_plugins(RenderAssetPlugin::<PreparedMaterial2d<M>>::default())
279            .add_systems(
280                PostUpdate,
281                check_entities_needing_specialization::<M>.after(AssetEventSystems),
282            );
283
284        if let Some(render_app) = app.get_sub_app_mut(RenderApp) {
285            render_app
286                .init_resource::<EntitySpecializationTicks<M>>()
287                .init_resource::<SpecializedMaterial2dPipelineCache<M>>()
288                .add_render_command::<Opaque2d, DrawMaterial2d<M>>()
289                .add_render_command::<AlphaMask2d, DrawMaterial2d<M>>()
290                .add_render_command::<Transparent2d, DrawMaterial2d<M>>()
291                .init_resource::<RenderMaterial2dInstances<M>>()
292                .init_resource::<SpecializedMeshPipelines<Material2dPipeline<M>>>()
293                .add_systems(
294                    RenderStartup,
295                    init_material_2d_pipeline::<M>.after(init_mesh_2d_pipeline),
296                )
297                .add_systems(
298                    ExtractSchedule,
299                    (
300                        extract_entities_needs_specialization::<M>.after(extract_cameras),
301                        extract_mesh_materials_2d::<M>,
302                    ),
303                )
304                .add_systems(
305                    Render,
306                    (
307                        specialize_material2d_meshes::<M>
308                            .in_set(RenderSystems::PrepareMeshes)
309                            .after(prepare_assets::<PreparedMaterial2d<M>>)
310                            .after(prepare_assets::<RenderMesh>),
311                        queue_material2d_meshes::<M>
312                            .in_set(RenderSystems::QueueMeshes)
313                            .after(prepare_assets::<PreparedMaterial2d<M>>),
314                    ),
315                );
316        }
317    }
318}
319
320#[derive(Resource, Deref, DerefMut)]
321pub struct RenderMaterial2dInstances<M: Material2d>(MainEntityHashMap<AssetId<M>>);
322
323impl<M: Material2d> Default for RenderMaterial2dInstances<M> {
324    fn default() -> Self {
325        Self(Default::default())
326    }
327}
328
329pub fn extract_mesh_materials_2d<M: Material2d>(
330    mut material_instances: ResMut<RenderMaterial2dInstances<M>>,
331    changed_meshes_query: Extract<
332        Query<
333            (Entity, &ViewVisibility, &MeshMaterial2d<M>),
334            Or<(Changed<ViewVisibility>, Changed<MeshMaterial2d<M>>)>,
335        >,
336    >,
337    mut removed_materials_query: Extract<RemovedComponents<MeshMaterial2d<M>>>,
338) {
339    for (entity, view_visibility, material) in &changed_meshes_query {
340        if view_visibility.get() {
341            add_mesh_instance(entity, material, &mut material_instances);
342        } else {
343            remove_mesh_instance(entity, &mut material_instances);
344        }
345    }
346
347    for entity in removed_materials_query.read() {
348        // Only queue a mesh for removal if we didn't pick it up above.
349        // It's possible that a necessary component was removed and re-added in
350        // the same frame.
351        if !changed_meshes_query.contains(entity) {
352            remove_mesh_instance(entity, &mut material_instances);
353        }
354    }
355
356    // Adds or updates a mesh instance in the [`RenderMaterial2dInstances`]
357    // array.
358    fn add_mesh_instance<M>(
359        entity: Entity,
360        material: &MeshMaterial2d<M>,
361        material_instances: &mut RenderMaterial2dInstances<M>,
362    ) where
363        M: Material2d,
364    {
365        material_instances.insert(entity.into(), material.id());
366    }
367
368    // Removes a mesh instance from the [`RenderMaterial2dInstances`] array.
369    fn remove_mesh_instance<M>(
370        entity: Entity,
371        material_instances: &mut RenderMaterial2dInstances<M>,
372    ) where
373        M: Material2d,
374    {
375        material_instances.remove(&MainEntity::from(entity));
376    }
377}
378
/// Render pipeline data for a given [`Material2d`].
///
/// Inserted as a resource by [`init_material_2d_pipeline`] (registered in `RenderStartup`). It
/// acts as the [`SpecializedMeshPipeline`] that produces the concrete pipelines for `M`.
#[derive(Resource)]
pub struct Material2dPipeline<M: Material2d> {
    /// Base 2D mesh pipeline. Its specialized descriptor is the starting point that the material
    /// then customizes.
    pub mesh2d_pipeline: Mesh2dPipeline,
    /// Bind group layout produced by `M::bind_group_layout`. It is used as the third pipeline
    /// layout entry (index 2, matching `MATERIAL_2D_BIND_GROUP_INDEX`).
    pub material2d_layout: BindGroupLayout,
    /// Custom vertex shader resolved from [`Material2d::vertex_shader`]; `None` keeps the mesh
    /// pipeline's vertex shader.
    pub vertex_shader: Option<Handle<Shader>>,
    /// Custom fragment shader resolved from [`Material2d::fragment_shader`]; `None` keeps the
    /// mesh pipeline's fragment shader.
    pub fragment_shader: Option<Handle<Shader>>,
    marker: PhantomData<M>,
}
388
389pub struct Material2dKey<M: Material2d> {
390    pub mesh_key: Mesh2dPipelineKey,
391    pub bind_group_data: M::Data,
392}
393
394impl<M: Material2d> Eq for Material2dKey<M> where M::Data: PartialEq {}
395
396impl<M: Material2d> PartialEq for Material2dKey<M>
397where
398    M::Data: PartialEq,
399{
400    fn eq(&self, other: &Self) -> bool {
401        self.mesh_key == other.mesh_key && self.bind_group_data == other.bind_group_data
402    }
403}
404
405impl<M: Material2d> Clone for Material2dKey<M>
406where
407    M::Data: Clone,
408{
409    fn clone(&self) -> Self {
410        Self {
411            mesh_key: self.mesh_key,
412            bind_group_data: self.bind_group_data.clone(),
413        }
414    }
415}
416
417impl<M: Material2d> Hash for Material2dKey<M>
418where
419    M::Data: Hash,
420{
421    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
422        self.mesh_key.hash(state);
423        self.bind_group_data.hash(state);
424    }
425}
426
427impl<M: Material2d> Clone for Material2dPipeline<M> {
428    fn clone(&self) -> Self {
429        Self {
430            mesh2d_pipeline: self.mesh2d_pipeline.clone(),
431            material2d_layout: self.material2d_layout.clone(),
432            vertex_shader: self.vertex_shader.clone(),
433            fragment_shader: self.fragment_shader.clone(),
434            marker: PhantomData,
435        }
436    }
437}
438
439impl<M: Material2d> SpecializedMeshPipeline for Material2dPipeline<M>
440where
441    M::Data: PartialEq + Eq + Hash + Clone,
442{
443    type Key = Material2dKey<M>;
444
445    fn specialize(
446        &self,
447        key: Self::Key,
448        layout: &MeshVertexBufferLayoutRef,
449    ) -> Result<RenderPipelineDescriptor, SpecializedMeshPipelineError> {
450        let mut descriptor = self.mesh2d_pipeline.specialize(key.mesh_key, layout)?;
451        descriptor.vertex.shader_defs.push(ShaderDefVal::UInt(
452            "MATERIAL_BIND_GROUP".into(),
453            MATERIAL_2D_BIND_GROUP_INDEX as u32,
454        ));
455        if let Some(ref mut fragment) = descriptor.fragment {
456            fragment.shader_defs.push(ShaderDefVal::UInt(
457                "MATERIAL_BIND_GROUP".into(),
458                MATERIAL_2D_BIND_GROUP_INDEX as u32,
459            ));
460        }
461        if let Some(vertex_shader) = &self.vertex_shader {
462            descriptor.vertex.shader = vertex_shader.clone();
463        }
464
465        if let Some(fragment_shader) = &self.fragment_shader {
466            descriptor.fragment.as_mut().unwrap().shader = fragment_shader.clone();
467        }
468        descriptor.layout = vec![
469            self.mesh2d_pipeline.view_layout.clone(),
470            self.mesh2d_pipeline.mesh_layout.clone(),
471            self.material2d_layout.clone(),
472        ];
473
474        M::specialize(&mut descriptor, layout, key)?;
475        Ok(descriptor)
476    }
477}
478
479pub fn init_material_2d_pipeline<M: Material2d>(
480    mut commands: Commands,
481    render_device: Res<RenderDevice>,
482    asset_server: Res<AssetServer>,
483    mesh_2d_pipeline: Res<Mesh2dPipeline>,
484) {
485    let material2d_layout = M::bind_group_layout(&render_device);
486
487    commands.insert_resource(Material2dPipeline::<M> {
488        mesh2d_pipeline: mesh_2d_pipeline.clone(),
489        material2d_layout,
490        vertex_shader: match M::vertex_shader() {
491            ShaderRef::Default => None,
492            ShaderRef::Handle(handle) => Some(handle),
493            ShaderRef::Path(path) => Some(asset_server.load(path)),
494        },
495        fragment_shader: match M::fragment_shader() {
496            ShaderRef::Default => None,
497            ShaderRef::Handle(handle) => Some(handle),
498            ShaderRef::Path(path) => Some(asset_server.load(path)),
499        },
500        marker: PhantomData,
501    });
502}
503
/// Draw function registered for the `Opaque2d`, `AlphaMask2d` and `Transparent2d` phases.
///
/// It runs these commands in order:
/// 1. Set the specialized pipeline.
/// 2. Bind the view bind group (group 0).
/// 3. Bind the mesh bind group (group 1).
/// 4. Bind the material bind group (group [`MATERIAL_2D_BIND_GROUP_INDEX`]).
/// 5. Draw the mesh.
pub(super) type DrawMaterial2d<M> = (
    SetItemPipeline,
    SetMesh2dViewBindGroup<0>,
    SetMesh2dBindGroup<1>,
    SetMaterial2dBindGroup<M, MATERIAL_2D_BIND_GROUP_INDEX>,
    DrawMesh2d,
);
511
512pub struct SetMaterial2dBindGroup<M: Material2d, const I: usize>(PhantomData<M>);
513impl<P: PhaseItem, M: Material2d, const I: usize> RenderCommand<P>
514    for SetMaterial2dBindGroup<M, I>
515{
516    type Param = (
517        SRes<RenderAssets<PreparedMaterial2d<M>>>,
518        SRes<RenderMaterial2dInstances<M>>,
519    );
520    type ViewQuery = ();
521    type ItemQuery = ();
522
523    #[inline]
524    fn render<'w>(
525        item: &P,
526        _view: (),
527        _item_query: Option<()>,
528        (materials, material_instances): SystemParamItem<'w, '_, Self::Param>,
529        pass: &mut TrackedRenderPass<'w>,
530    ) -> RenderCommandResult {
531        let materials = materials.into_inner();
532        let material_instances = material_instances.into_inner();
533        let Some(material_instance) = material_instances.get(&item.main_entity()) else {
534            return RenderCommandResult::Skip;
535        };
536        let Some(material2d) = materials.get(*material_instance) else {
537            return RenderCommandResult::Skip;
538        };
539        pass.set_bind_group(I, &material2d.bind_group, &[]);
540        RenderCommandResult::Success
541    }
542}
543
544pub const fn alpha_mode_pipeline_key(alpha_mode: AlphaMode2d) -> Mesh2dPipelineKey {
545    match alpha_mode {
546        AlphaMode2d::Blend => Mesh2dPipelineKey::BLEND_ALPHA,
547        AlphaMode2d::Mask(_) => Mesh2dPipelineKey::MAY_DISCARD,
548        _ => Mesh2dPipelineKey::NONE,
549    }
550}
551
552pub const fn tonemapping_pipeline_key(tonemapping: Tonemapping) -> Mesh2dPipelineKey {
553    match tonemapping {
554        Tonemapping::None => Mesh2dPipelineKey::TONEMAP_METHOD_NONE,
555        Tonemapping::Reinhard => Mesh2dPipelineKey::TONEMAP_METHOD_REINHARD,
556        Tonemapping::ReinhardLuminance => Mesh2dPipelineKey::TONEMAP_METHOD_REINHARD_LUMINANCE,
557        Tonemapping::AcesFitted => Mesh2dPipelineKey::TONEMAP_METHOD_ACES_FITTED,
558        Tonemapping::AgX => Mesh2dPipelineKey::TONEMAP_METHOD_AGX,
559        Tonemapping::SomewhatBoringDisplayTransform => {
560            Mesh2dPipelineKey::TONEMAP_METHOD_SOMEWHAT_BORING_DISPLAY_TRANSFORM
561        }
562        Tonemapping::TonyMcMapface => Mesh2dPipelineKey::TONEMAP_METHOD_TONY_MC_MAPFACE,
563        Tonemapping::BlenderFilmic => Mesh2dPipelineKey::TONEMAP_METHOD_BLENDER_FILMIC,
564    }
565}
566
567pub fn extract_entities_needs_specialization<M>(
568    entities_needing_specialization: Extract<Res<EntitiesNeedingSpecialization<M>>>,
569    mut entity_specialization_ticks: ResMut<EntitySpecializationTicks<M>>,
570    mut removed_mesh_material_components: Extract<RemovedComponents<MeshMaterial2d<M>>>,
571    mut specialized_material2d_pipeline_cache: ResMut<SpecializedMaterial2dPipelineCache<M>>,
572    views: Query<&MainEntity, With<ExtractedView>>,
573    ticks: SystemChangeTick,
574) where
575    M: Material2d,
576{
577    // Clean up any despawned entities, we do this first in case the removed material was re-added
578    // the same frame, thus will appear both in the removed components list and have been added to
579    // the `EntitiesNeedingSpecialization` collection by triggering the `Changed` filter
580    for entity in removed_mesh_material_components.read() {
581        entity_specialization_ticks.remove(&MainEntity::from(entity));
582        for view in views {
583            if let Some(cache) = specialized_material2d_pipeline_cache.get_mut(view) {
584                cache.remove(&MainEntity::from(entity));
585            }
586        }
587    }
588    for entity in entities_needing_specialization.iter() {
589        // Update the entity's specialization tick with this run's tick
590        entity_specialization_ticks.insert((*entity).into(), ticks.this_run());
591    }
592}
593
594#[derive(Clone, Resource, Deref, DerefMut, Debug)]
595pub struct EntitiesNeedingSpecialization<M> {
596    #[deref]
597    pub entities: Vec<Entity>,
598    _marker: PhantomData<M>,
599}
600
601impl<M> Default for EntitiesNeedingSpecialization<M> {
602    fn default() -> Self {
603        Self {
604            entities: Default::default(),
605            _marker: Default::default(),
606        }
607    }
608}
609
610#[derive(Clone, Resource, Deref, DerefMut, Debug)]
611pub struct EntitySpecializationTicks<M> {
612    #[deref]
613    pub entities: MainEntityHashMap<Tick>,
614    _marker: PhantomData<M>,
615}
616
617impl<M> Default for EntitySpecializationTicks<M> {
618    fn default() -> Self {
619        Self {
620            entities: MainEntityHashMap::default(),
621            _marker: Default::default(),
622        }
623    }
624}
625
626/// Stores the [`SpecializedMaterial2dViewPipelineCache`] for each view.
627#[derive(Resource, Deref, DerefMut)]
628pub struct SpecializedMaterial2dPipelineCache<M> {
629    // view_entity -> view pipeline cache
630    #[deref]
631    map: MainEntityHashMap<SpecializedMaterial2dViewPipelineCache<M>>,
632    marker: PhantomData<M>,
633}
634
635/// Stores the cached render pipeline ID for each entity in a single view, as
636/// well as the last time it was changed.
637#[derive(Deref, DerefMut)]
638pub struct SpecializedMaterial2dViewPipelineCache<M> {
639    // material entity -> (tick, pipeline_id)
640    #[deref]
641    map: MainEntityHashMap<(Tick, CachedRenderPipelineId)>,
642    marker: PhantomData<M>,
643}
644
645impl<M> Default for SpecializedMaterial2dPipelineCache<M> {
646    fn default() -> Self {
647        Self {
648            map: HashMap::default(),
649            marker: PhantomData,
650        }
651    }
652}
653
654impl<M> Default for SpecializedMaterial2dViewPipelineCache<M> {
655    fn default() -> Self {
656        Self {
657            map: HashMap::default(),
658            marker: PhantomData,
659        }
660    }
661}
662
663pub fn check_entities_needing_specialization<M>(
664    needs_specialization: Query<
665        Entity,
666        (
667            Or<(
668                Changed<Mesh2d>,
669                AssetChanged<Mesh2d>,
670                Changed<MeshMaterial2d<M>>,
671                AssetChanged<MeshMaterial2d<M>>,
672            )>,
673            With<MeshMaterial2d<M>>,
674        ),
675    >,
676    mut par_local: Local<Parallel<Vec<Entity>>>,
677    mut entities_needing_specialization: ResMut<EntitiesNeedingSpecialization<M>>,
678) where
679    M: Material2d,
680{
681    entities_needing_specialization.clear();
682
683    needs_specialization
684        .par_iter()
685        .for_each(|entity| par_local.borrow_local_mut().push(entity));
686
687    par_local.drain_into(&mut entities_needing_specialization);
688}
689
690pub fn specialize_material2d_meshes<M: Material2d>(
691    material2d_pipeline: Res<Material2dPipeline<M>>,
692    mut pipelines: ResMut<SpecializedMeshPipelines<Material2dPipeline<M>>>,
693    pipeline_cache: Res<PipelineCache>,
694    (render_meshes, render_materials): (
695        Res<RenderAssets<RenderMesh>>,
696        Res<RenderAssets<PreparedMaterial2d<M>>>,
697    ),
698    mut render_mesh_instances: ResMut<RenderMesh2dInstances>,
699    render_material_instances: Res<RenderMaterial2dInstances<M>>,
700    transparent_render_phases: Res<ViewSortedRenderPhases<Transparent2d>>,
701    opaque_render_phases: Res<ViewBinnedRenderPhases<Opaque2d>>,
702    alpha_mask_render_phases: Res<ViewBinnedRenderPhases<AlphaMask2d>>,
703    views: Query<(&MainEntity, &ExtractedView, &RenderVisibleEntities)>,
704    view_key_cache: Res<ViewKeyCache>,
705    entity_specialization_ticks: Res<EntitySpecializationTicks<M>>,
706    view_specialization_ticks: Res<ViewSpecializationTicks>,
707    ticks: SystemChangeTick,
708    mut specialized_material_pipeline_cache: ResMut<SpecializedMaterial2dPipelineCache<M>>,
709) where
710    M::Data: PartialEq + Eq + Hash + Clone,
711{
712    if render_material_instances.is_empty() {
713        return;
714    }
715
716    for (view_entity, view, visible_entities) in &views {
717        if !transparent_render_phases.contains_key(&view.retained_view_entity)
718            && !opaque_render_phases.contains_key(&view.retained_view_entity)
719            && !alpha_mask_render_phases.contains_key(&view.retained_view_entity)
720        {
721            continue;
722        }
723
724        let Some(view_key) = view_key_cache.get(view_entity) else {
725            continue;
726        };
727
728        let view_tick = view_specialization_ticks.get(view_entity).unwrap();
729        let view_specialized_material_pipeline_cache = specialized_material_pipeline_cache
730            .entry(*view_entity)
731            .or_default();
732
733        for (_, visible_entity) in visible_entities.iter::<Mesh2d>() {
734            let Some(material_asset_id) = render_material_instances.get(visible_entity) else {
735                continue;
736            };
737            let Some(mesh_instance) = render_mesh_instances.get_mut(visible_entity) else {
738                continue;
739            };
740            let Some(entity_tick) = entity_specialization_ticks.get(visible_entity) else {
741                error!("{visible_entity:?} is missing specialization tick. Spawning Meshes in PostUpdate or later is currently not fully supported.");
742                continue;
743            };
744            let last_specialized_tick = view_specialized_material_pipeline_cache
745                .get(visible_entity)
746                .map(|(tick, _)| *tick);
747            let needs_specialization = last_specialized_tick.is_none_or(|tick| {
748                view_tick.is_newer_than(tick, ticks.this_run())
749                    || entity_tick.is_newer_than(tick, ticks.this_run())
750            });
751            if !needs_specialization {
752                continue;
753            }
754            let Some(material_2d) = render_materials.get(*material_asset_id) else {
755                continue;
756            };
757            let Some(mesh) = render_meshes.get(mesh_instance.mesh_asset_id) else {
758                continue;
759            };
760            let mesh_key = *view_key
761                | Mesh2dPipelineKey::from_primitive_topology(mesh.primitive_topology())
762                | material_2d.properties.mesh_pipeline_key_bits;
763
764            let pipeline_id = pipelines.specialize(
765                &pipeline_cache,
766                &material2d_pipeline,
767                Material2dKey {
768                    mesh_key,
769                    bind_group_data: material_2d.key.clone(),
770                },
771                &mesh.layout,
772            );
773
774            let pipeline_id = match pipeline_id {
775                Ok(id) => id,
776                Err(err) => {
777                    error!("{}", err);
778                    continue;
779                }
780            };
781
782            view_specialized_material_pipeline_cache
783                .insert(*visible_entity, (ticks.this_run(), pipeline_id));
784        }
785    }
786}
787
788pub fn queue_material2d_meshes<M: Material2d>(
789    (render_meshes, render_materials): (
790        Res<RenderAssets<RenderMesh>>,
791        Res<RenderAssets<PreparedMaterial2d<M>>>,
792    ),
793    mut render_mesh_instances: ResMut<RenderMesh2dInstances>,
794    render_material_instances: Res<RenderMaterial2dInstances<M>>,
795    mut transparent_render_phases: ResMut<ViewSortedRenderPhases<Transparent2d>>,
796    mut opaque_render_phases: ResMut<ViewBinnedRenderPhases<Opaque2d>>,
797    mut alpha_mask_render_phases: ResMut<ViewBinnedRenderPhases<AlphaMask2d>>,
798    views: Query<(&MainEntity, &ExtractedView, &RenderVisibleEntities)>,
799    specialized_material_pipeline_cache: ResMut<SpecializedMaterial2dPipelineCache<M>>,
800) where
801    M::Data: PartialEq + Eq + Hash + Clone,
802{
803    if render_material_instances.is_empty() {
804        return;
805    }
806
807    for (view_entity, view, visible_entities) in &views {
808        let Some(view_specialized_material_pipeline_cache) =
809            specialized_material_pipeline_cache.get(view_entity)
810        else {
811            continue;
812        };
813
814        let Some(transparent_phase) = transparent_render_phases.get_mut(&view.retained_view_entity)
815        else {
816            continue;
817        };
818        let Some(opaque_phase) = opaque_render_phases.get_mut(&view.retained_view_entity) else {
819            continue;
820        };
821        let Some(alpha_mask_phase) = alpha_mask_render_phases.get_mut(&view.retained_view_entity)
822        else {
823            continue;
824        };
825
826        for (render_entity, visible_entity) in visible_entities.iter::<Mesh2d>() {
827            let Some((current_change_tick, pipeline_id)) = view_specialized_material_pipeline_cache
828                .get(visible_entity)
829                .map(|(current_change_tick, pipeline_id)| (*current_change_tick, *pipeline_id))
830            else {
831                continue;
832            };
833
834            // Skip the entity if it's cached in a bin and up to date.
835            if opaque_phase.validate_cached_entity(*visible_entity, current_change_tick)
836                || alpha_mask_phase.validate_cached_entity(*visible_entity, current_change_tick)
837            {
838                continue;
839            }
840
841            let Some(material_asset_id) = render_material_instances.get(visible_entity) else {
842                continue;
843            };
844            let Some(mesh_instance) = render_mesh_instances.get_mut(visible_entity) else {
845                continue;
846            };
847            let Some(material_2d) = render_materials.get(*material_asset_id) else {
848                continue;
849            };
850            let Some(mesh) = render_meshes.get(mesh_instance.mesh_asset_id) else {
851                continue;
852            };
853
854            mesh_instance.material_bind_group_id = material_2d.get_bind_group_id();
855            let mesh_z = mesh_instance.transforms.world_from_local.translation.z;
856
857            // We don't support multidraw yet for 2D meshes, so we use this
858            // custom logic to generate the `BinnedRenderPhaseType` instead of
859            // `BinnedRenderPhaseType::mesh`, which can return
860            // `BinnedRenderPhaseType::MultidrawableMesh` if the hardware
861            // supports multidraw.
862            let binned_render_phase_type = if mesh_instance.automatic_batching {
863                BinnedRenderPhaseType::BatchableMesh
864            } else {
865                BinnedRenderPhaseType::UnbatchableMesh
866            };
867
868            match material_2d.properties.alpha_mode {
869                AlphaMode2d::Opaque => {
870                    let bin_key = Opaque2dBinKey {
871                        pipeline: pipeline_id,
872                        draw_function: material_2d.properties.draw_function_id,
873                        asset_id: mesh_instance.mesh_asset_id.into(),
874                        material_bind_group_id: material_2d.get_bind_group_id().0,
875                    };
876                    opaque_phase.add(
877                        BatchSetKey2d {
878                            indexed: mesh.indexed(),
879                        },
880                        bin_key,
881                        (*render_entity, *visible_entity),
882                        InputUniformIndex::default(),
883                        binned_render_phase_type,
884                        current_change_tick,
885                    );
886                }
887                AlphaMode2d::Mask(_) => {
888                    let bin_key = AlphaMask2dBinKey {
889                        pipeline: pipeline_id,
890                        draw_function: material_2d.properties.draw_function_id,
891                        asset_id: mesh_instance.mesh_asset_id.into(),
892                        material_bind_group_id: material_2d.get_bind_group_id().0,
893                    };
894                    alpha_mask_phase.add(
895                        BatchSetKey2d {
896                            indexed: mesh.indexed(),
897                        },
898                        bin_key,
899                        (*render_entity, *visible_entity),
900                        InputUniformIndex::default(),
901                        binned_render_phase_type,
902                        current_change_tick,
903                    );
904                }
905                AlphaMode2d::Blend => {
906                    transparent_phase.add(Transparent2d {
907                        entity: (*render_entity, *visible_entity),
908                        draw_function: material_2d.properties.draw_function_id,
909                        pipeline: pipeline_id,
910                        // NOTE: Back-to-front ordering for transparent with ascending sort means far should have the
911                        // lowest sort key and getting closer should increase. As we have
912                        // -z in front of the camera, the largest distance is -far with values increasing toward the
913                        // camera. As such we can just use mesh_z as the distance
914                        sort_key: FloatOrd(mesh_z + material_2d.properties.depth_bias),
915                        // Batching is done in batch_and_prepare_render_phase
916                        batch_range: 0..1,
917                        extra_index: PhaseItemExtraIndex::None,
918                        extracted_index: usize::MAX,
919                        indexed: mesh.indexed(),
920                    });
921                }
922            }
923        }
924    }
925}
926
927#[derive(Component, Clone, Copy, Default, PartialEq, Eq, Deref, DerefMut)]
928pub struct Material2dBindGroupId(pub Option<BindGroupId>);
929
930/// Common [`Material2d`] properties, calculated for a specific material instance.
931pub struct Material2dProperties {
932    /// The [`AlphaMode2d`] of this material.
933    pub alpha_mode: AlphaMode2d,
934    /// Add a bias to the view depth of the mesh which can be used to force a specific render order
935    /// for meshes with equal depth, to avoid z-fighting.
936    /// The bias is in depth-texture units so large values may
937    pub depth_bias: f32,
938    /// The bits in the [`Mesh2dPipelineKey`] for this material.
939    ///
940    /// These are precalculated so that we can just "or" them together in
941    /// [`queue_material2d_meshes`].
942    pub mesh_pipeline_key_bits: Mesh2dPipelineKey,
943    pub draw_function_id: DrawFunctionId,
944}
945
946/// Data prepared for a [`Material2d`] instance.
947pub struct PreparedMaterial2d<T: Material2d> {
948    pub bindings: BindingResources,
949    pub bind_group: BindGroup,
950    pub key: T::Data,
951    pub properties: Material2dProperties,
952}
953
954impl<T: Material2d> PreparedMaterial2d<T> {
955    pub fn get_bind_group_id(&self) -> Material2dBindGroupId {
956        Material2dBindGroupId(Some(self.bind_group.id()))
957    }
958}
959
960impl<M: Material2d> RenderAsset for PreparedMaterial2d<M> {
961    type SourceAsset = M;
962
963    type Param = (
964        SRes<RenderDevice>,
965        SRes<Material2dPipeline<M>>,
966        SRes<DrawFunctions<Opaque2d>>,
967        SRes<DrawFunctions<AlphaMask2d>>,
968        SRes<DrawFunctions<Transparent2d>>,
969        M::Param,
970    );
971
972    fn prepare_asset(
973        material: Self::SourceAsset,
974        _: AssetId<Self::SourceAsset>,
975        (
976            render_device,
977            pipeline,
978            opaque_draw_functions,
979            alpha_mask_draw_functions,
980            transparent_draw_functions,
981            material_param,
982        ): &mut SystemParamItem<Self::Param>,
983        _: Option<&Self>,
984    ) -> Result<Self, PrepareAssetError<Self::SourceAsset>> {
985        let bind_group_data = material.bind_group_data();
986        match material.as_bind_group(&pipeline.material2d_layout, render_device, material_param) {
987            Ok(prepared) => {
988                let mut mesh_pipeline_key_bits = Mesh2dPipelineKey::empty();
989                mesh_pipeline_key_bits.insert(alpha_mode_pipeline_key(material.alpha_mode()));
990
991                let draw_function_id = match material.alpha_mode() {
992                    AlphaMode2d::Opaque => opaque_draw_functions.read().id::<DrawMaterial2d<M>>(),
993                    AlphaMode2d::Mask(_) => {
994                        alpha_mask_draw_functions.read().id::<DrawMaterial2d<M>>()
995                    }
996                    AlphaMode2d::Blend => {
997                        transparent_draw_functions.read().id::<DrawMaterial2d<M>>()
998                    }
999                };
1000
1001                Ok(PreparedMaterial2d {
1002                    bindings: prepared.bindings,
1003                    bind_group: prepared.bind_group,
1004                    key: bind_group_data,
1005                    properties: Material2dProperties {
1006                        depth_bias: material.depth_bias(),
1007                        alpha_mode: material.alpha_mode(),
1008                        mesh_pipeline_key_bits,
1009                        draw_function_id,
1010                    },
1011                })
1012            }
1013            Err(AsBindGroupError::RetryNextUpdate) => {
1014                Err(PrepareAssetError::RetryNextUpdate(material))
1015            }
1016            Err(other) => Err(PrepareAssetError::AsBindGroupError(other)),
1017        }
1018    }
1019}