custom_render_phase/
custom_render_phase.rs

//! This example demonstrates how to write a custom render phase.
//!
//! Render phases in bevy are used whenever you need to draw a group of meshes in a specific way.
//! For example, bevy's main pass has an opaque phase and a transparent phase, for both 2d and 3d.
//! Sometimes, you may want to draw only a subset of meshes before or after the built-in phases.
//! In those situations you need to write your own phase.
7//!
//! This example showcases what a custom phase that draws a stencil of a bevy mesh could look
//! like. Some shortcuts have been used for simplicity.
10//!
11//! This example was made for 3d, but a 2d equivalent would be almost identical.
12
13use std::ops::Range;
14
15use bevy::{
16    core_pipeline::core_3d::graph::{Core3d, Node3d},
17    ecs::{
18        query::QueryItem,
19        system::{lifetimeless::SRes, SystemParamItem},
20    },
21    math::FloatOrd,
22    pbr::{
23        DrawMesh, MeshInputUniform, MeshPipeline, MeshPipelineKey, MeshPipelineViewLayoutKey,
24        MeshUniform, RenderMeshInstances, SetMeshBindGroup, SetMeshViewBindGroup,
25    },
26    platform::collections::HashSet,
27    prelude::*,
28    render::{
29        batching::{
30            gpu_preprocessing::{
31                batch_and_prepare_sorted_render_phase, IndirectParametersCpuMetadata,
32                UntypedPhaseIndirectParametersBuffers,
33            },
34            GetBatchData, GetFullBatchData,
35        },
36        camera::ExtractedCamera,
37        extract_component::{ExtractComponent, ExtractComponentPlugin},
38        mesh::{allocator::MeshAllocator, MeshVertexBufferLayoutRef, RenderMesh},
39        render_asset::RenderAssets,
40        render_graph::{
41            NodeRunError, RenderGraphApp, RenderGraphContext, RenderLabel, ViewNode, ViewNodeRunner,
42        },
43        render_phase::{
44            sort_phase_system, AddRenderCommand, CachedRenderPipelinePhaseItem, DrawFunctionId,
45            DrawFunctions, PhaseItem, PhaseItemExtraIndex, SetItemPipeline, SortedPhaseItem,
46            SortedRenderPhasePlugin, ViewSortedRenderPhases,
47        },
48        render_resource::{
49            CachedRenderPipelineId, ColorTargetState, ColorWrites, Face, FragmentState, FrontFace,
50            MultisampleState, PipelineCache, PolygonMode, PrimitiveState, RenderPassDescriptor,
51            RenderPipelineDescriptor, SpecializedMeshPipeline, SpecializedMeshPipelineError,
52            SpecializedMeshPipelines, TextureFormat, VertexState,
53        },
54        renderer::RenderContext,
55        sync_world::MainEntity,
56        view::{ExtractedView, RenderVisibleEntities, RetainedViewEntity, ViewTarget},
57        Extract, Render, RenderApp, RenderDebugFlags, RenderSet,
58    },
59};
60use nonmax::NonMaxU32;
61
/// Asset path of the WGSL shader that supplies both the `vertex` and `fragment` entry points of
/// the stencil pipeline.
const SHADER_ASSET_PATH: &str = "shaders/custom_stencil.wgsl";
63
64fn main() {
65    App::new()
66        .add_plugins((DefaultPlugins, MeshStencilPhasePlugin))
67        .add_systems(Startup, setup)
68        .run();
69}
70
71fn setup(
72    mut commands: Commands,
73    mut meshes: ResMut<Assets<Mesh>>,
74    mut materials: ResMut<Assets<StandardMaterial>>,
75) {
76    // circular base
77    commands.spawn((
78        Mesh3d(meshes.add(Circle::new(4.0))),
79        MeshMaterial3d(materials.add(Color::WHITE)),
80        Transform::from_rotation(Quat::from_rotation_x(-std::f32::consts::FRAC_PI_2)),
81    ));
82    // cube
83    // This cube will be rendered by the main pass, but it will also be rendered by our custom
84    // pass. This should result in an unlit red cube
85    commands.spawn((
86        Mesh3d(meshes.add(Cuboid::new(1.0, 1.0, 1.0))),
87        MeshMaterial3d(materials.add(Color::srgb_u8(124, 144, 255))),
88        Transform::from_xyz(0.0, 0.5, 0.0),
89        // This marker component is used to identify which mesh will be used in our custom pass
90        // The circle doesn't have it so it won't be rendered in our pass
91        DrawStencil,
92    ));
93    // light
94    commands.spawn((
95        PointLight {
96            shadows_enabled: true,
97            ..default()
98        },
99        Transform::from_xyz(4.0, 8.0, 4.0),
100    ));
101    // camera
102    commands.spawn((
103        Camera3d::default(),
104        Transform::from_xyz(-2.0, 4.5, 9.0).looking_at(Vec3::ZERO, Vec3::Y),
105        // disable msaa for simplicity
106        Msaa::Off,
107    ));
108}
109
110#[derive(Component, ExtractComponent, Clone, Copy, Default)]
111struct DrawStencil;
112
113struct MeshStencilPhasePlugin;
114impl Plugin for MeshStencilPhasePlugin {
115    fn build(&self, app: &mut App) {
116        app.add_plugins((
117            ExtractComponentPlugin::<DrawStencil>::default(),
118            SortedRenderPhasePlugin::<Stencil3d, MeshPipeline>::new(RenderDebugFlags::default()),
119        ));
120        // We need to get the render app from the main app
121        let Some(render_app) = app.get_sub_app_mut(RenderApp) else {
122            return;
123        };
124        render_app
125            .init_resource::<SpecializedMeshPipelines<StencilPipeline>>()
126            .init_resource::<DrawFunctions<Stencil3d>>()
127            .add_render_command::<Stencil3d, DrawMesh3dStencil>()
128            .init_resource::<ViewSortedRenderPhases<Stencil3d>>()
129            .add_systems(ExtractSchedule, extract_camera_phases)
130            .add_systems(
131                Render,
132                (
133                    queue_custom_meshes.in_set(RenderSet::QueueMeshes),
134                    sort_phase_system::<Stencil3d>.in_set(RenderSet::PhaseSort),
135                    batch_and_prepare_sorted_render_phase::<Stencil3d, StencilPipeline>
136                        .in_set(RenderSet::PrepareResources),
137                ),
138            );
139
140        render_app
141            .add_render_graph_node::<ViewNodeRunner<CustomDrawNode>>(Core3d, CustomDrawPassLabel)
142            // Tell the node to run after the main pass
143            .add_render_graph_edges(Core3d, (Node3d::MainOpaquePass, CustomDrawPassLabel));
144    }
145
146    fn finish(&self, app: &mut App) {
147        // We need to get the render app from the main app
148        let Some(render_app) = app.get_sub_app_mut(RenderApp) else {
149            return;
150        };
151        // The pipeline needs the RenderDevice to be created and it's only available once plugins
152        // are initialized
153        render_app.init_resource::<StencilPipeline>();
154    }
155}
156
157#[derive(Resource)]
158struct StencilPipeline {
159    /// The base mesh pipeline defined by bevy
160    ///
161    /// Since we want to draw a stencil of an existing bevy mesh we want to reuse the default
162    /// pipeline as much as possible
163    mesh_pipeline: MeshPipeline,
164    /// Stores the shader used for this pipeline directly on the pipeline.
165    /// This isn't required, it's only done like this for simplicity.
166    shader_handle: Handle<Shader>,
167}
168impl FromWorld for StencilPipeline {
169    fn from_world(world: &mut World) -> Self {
170        Self {
171            mesh_pipeline: MeshPipeline::from_world(world),
172            shader_handle: world.resource::<AssetServer>().load(SHADER_ASSET_PATH),
173        }
174    }
175}
176
// For more information on how `SpecializedMeshPipeline` works, please look at the
// `specialized_mesh_pipeline` example.
179impl SpecializedMeshPipeline for StencilPipeline {
180    type Key = MeshPipelineKey;
181
182    fn specialize(
183        &self,
184        key: Self::Key,
185        layout: &MeshVertexBufferLayoutRef,
186    ) -> Result<RenderPipelineDescriptor, SpecializedMeshPipelineError> {
187        // We will only use the position of the mesh in our shader so we only need to specify that
188        let mut vertex_attributes = Vec::new();
189        if layout.0.contains(Mesh::ATTRIBUTE_POSITION) {
190            // Make sure this matches the shader location
191            vertex_attributes.push(Mesh::ATTRIBUTE_POSITION.at_shader_location(0));
192        }
193        // This will automatically generate the correct `VertexBufferLayout` based on the vertex attributes
194        let vertex_buffer_layout = layout.0.get_layout(&vertex_attributes)?;
195
196        Ok(RenderPipelineDescriptor {
197            label: Some("Specialized Mesh Pipeline".into()),
198            // We want to reuse the data from bevy so we use the same bind groups as the default
199            // mesh pipeline
200            layout: vec![
201                // Bind group 0 is the view uniform
202                self.mesh_pipeline
203                    .get_view_layout(MeshPipelineViewLayoutKey::from(key))
204                    .clone(),
205                // Bind group 1 is the mesh uniform
206                self.mesh_pipeline.mesh_layouts.model_only.clone(),
207            ],
208            push_constant_ranges: vec![],
209            vertex: VertexState {
210                shader: self.shader_handle.clone(),
211                shader_defs: vec![],
212                entry_point: "vertex".into(),
213                buffers: vec![vertex_buffer_layout],
214            },
215            fragment: Some(FragmentState {
216                shader: self.shader_handle.clone(),
217                shader_defs: vec![],
218                entry_point: "fragment".into(),
219                targets: vec![Some(ColorTargetState {
220                    format: TextureFormat::bevy_default(),
221                    blend: None,
222                    write_mask: ColorWrites::ALL,
223                })],
224            }),
225            primitive: PrimitiveState {
226                topology: key.primitive_topology(),
227                front_face: FrontFace::Ccw,
228                cull_mode: Some(Face::Back),
229                polygon_mode: PolygonMode::Fill,
230                ..default()
231            },
232            depth_stencil: None,
233            // It's generally recommended to specialize your pipeline for MSAA,
234            // but it's not always possible
235            multisample: MultisampleState::default(),
236            zero_initialize_workgroup_memory: false,
237        })
238    }
239}
240
241// We will reuse render commands already defined by bevy to draw a 3d mesh
242type DrawMesh3dStencil = (
243    SetItemPipeline,
244    // This will set the view bindings in group 0
245    SetMeshViewBindGroup<0>,
246    // This will set the mesh bindings in group 1
247    SetMeshBindGroup<1>,
248    // This will draw the mesh
249    DrawMesh,
250);
251
// This is the data required per entity drawn in a custom phase in bevy. More specifically,
// this is the data required when using `ViewSortedRenderPhases`. This would look different if
// we wanted a batched render phase. Sorted phases are a bit easier to implement, but a batched
// phase would look similar.
//
// If you want to see how a batched phase implementation looks, you should look at the
// `Opaque2d` phase.
259struct Stencil3d {
260    pub sort_key: FloatOrd,
261    pub entity: (Entity, MainEntity),
262    pub pipeline: CachedRenderPipelineId,
263    pub draw_function: DrawFunctionId,
264    pub batch_range: Range<u32>,
265    pub extra_index: PhaseItemExtraIndex,
266    /// Whether the mesh in question is indexed (uses an index buffer in
267    /// addition to its vertex buffer).
268    pub indexed: bool,
269}
270
271// For more information about writing a phase item, please look at the custom_phase_item example
272impl PhaseItem for Stencil3d {
273    #[inline]
274    fn entity(&self) -> Entity {
275        self.entity.0
276    }
277
278    #[inline]
279    fn main_entity(&self) -> MainEntity {
280        self.entity.1
281    }
282
283    #[inline]
284    fn draw_function(&self) -> DrawFunctionId {
285        self.draw_function
286    }
287
288    #[inline]
289    fn batch_range(&self) -> &Range<u32> {
290        &self.batch_range
291    }
292
293    #[inline]
294    fn batch_range_mut(&mut self) -> &mut Range<u32> {
295        &mut self.batch_range
296    }
297
298    #[inline]
299    fn extra_index(&self) -> PhaseItemExtraIndex {
300        self.extra_index.clone()
301    }
302
303    #[inline]
304    fn batch_range_and_extra_index_mut(&mut self) -> (&mut Range<u32>, &mut PhaseItemExtraIndex) {
305        (&mut self.batch_range, &mut self.extra_index)
306    }
307}
308
309impl SortedPhaseItem for Stencil3d {
310    type SortKey = FloatOrd;
311
312    #[inline]
313    fn sort_key(&self) -> Self::SortKey {
314        self.sort_key
315    }
316
317    #[inline]
318    fn sort(items: &mut [Self]) {
319        // bevy normally uses radsort instead of the std slice::sort_by_key
320        // radsort is a stable radix sort that performed better than `slice::sort_by_key` or `slice::sort_unstable_by_key`.
321        // Since it is not re-exported by bevy, we just use the std sort for the purpose of the example
322        items.sort_by_key(SortedPhaseItem::sort_key);
323    }
324
325    #[inline]
326    fn indexed(&self) -> bool {
327        self.indexed
328    }
329}
330
331impl CachedRenderPipelinePhaseItem for Stencil3d {
332    #[inline]
333    fn cached_pipeline(&self) -> CachedRenderPipelineId {
334        self.pipeline
335    }
336}
337
338impl GetBatchData for StencilPipeline {
339    type Param = (
340        SRes<RenderMeshInstances>,
341        SRes<RenderAssets<RenderMesh>>,
342        SRes<MeshAllocator>,
343    );
344    type CompareData = AssetId<Mesh>;
345    type BufferData = MeshUniform;
346
347    fn get_batch_data(
348        (mesh_instances, _render_assets, mesh_allocator): &SystemParamItem<Self::Param>,
349        (_entity, main_entity): (Entity, MainEntity),
350    ) -> Option<(Self::BufferData, Option<Self::CompareData>)> {
351        let RenderMeshInstances::CpuBuilding(ref mesh_instances) = **mesh_instances else {
352            error!(
353                "`get_batch_data` should never be called in GPU mesh uniform \
354                building mode"
355            );
356            return None;
357        };
358        let mesh_instance = mesh_instances.get(&main_entity)?;
359        let first_vertex_index =
360            match mesh_allocator.mesh_vertex_slice(&mesh_instance.mesh_asset_id) {
361                Some(mesh_vertex_slice) => mesh_vertex_slice.range.start,
362                None => 0,
363            };
364        let mesh_uniform = {
365            let mesh_transforms = &mesh_instance.transforms;
366            let (local_from_world_transpose_a, local_from_world_transpose_b) =
367                mesh_transforms.world_from_local.inverse_transpose_3x3();
368            MeshUniform {
369                world_from_local: mesh_transforms.world_from_local.to_transpose(),
370                previous_world_from_local: mesh_transforms.previous_world_from_local.to_transpose(),
371                lightmap_uv_rect: UVec2::ZERO,
372                local_from_world_transpose_a,
373                local_from_world_transpose_b,
374                flags: mesh_transforms.flags,
375                first_vertex_index,
376                current_skin_index: u32::MAX,
377                material_and_lightmap_bind_group_slot: 0,
378                tag: 0,
379                pad: 0,
380            }
381        };
382        Some((mesh_uniform, None))
383    }
384}
385impl GetFullBatchData for StencilPipeline {
386    type BufferInputData = MeshInputUniform;
387
388    fn get_index_and_compare_data(
389        (mesh_instances, _, _): &SystemParamItem<Self::Param>,
390        main_entity: MainEntity,
391    ) -> Option<(NonMaxU32, Option<Self::CompareData>)> {
392        // This should only be called during GPU building.
393        let RenderMeshInstances::GpuBuilding(ref mesh_instances) = **mesh_instances else {
394            error!(
395                "`get_index_and_compare_data` should never be called in CPU mesh uniform building \
396                mode"
397            );
398            return None;
399        };
400        let mesh_instance = mesh_instances.get(&main_entity)?;
401        Some((
402            mesh_instance.current_uniform_index,
403            mesh_instance
404                .should_batch()
405                .then_some(mesh_instance.mesh_asset_id),
406        ))
407    }
408
409    fn get_binned_batch_data(
410        (mesh_instances, _render_assets, mesh_allocator): &SystemParamItem<Self::Param>,
411        main_entity: MainEntity,
412    ) -> Option<Self::BufferData> {
413        let RenderMeshInstances::CpuBuilding(ref mesh_instances) = **mesh_instances else {
414            error!(
415                "`get_binned_batch_data` should never be called in GPU mesh uniform building mode"
416            );
417            return None;
418        };
419        let mesh_instance = mesh_instances.get(&main_entity)?;
420        let first_vertex_index =
421            match mesh_allocator.mesh_vertex_slice(&mesh_instance.mesh_asset_id) {
422                Some(mesh_vertex_slice) => mesh_vertex_slice.range.start,
423                None => 0,
424            };
425
426        Some(MeshUniform::new(
427            &mesh_instance.transforms,
428            first_vertex_index,
429            mesh_instance.material_bindings_index.slot,
430            None,
431            None,
432            None,
433        ))
434    }
435
436    fn write_batch_indirect_parameters_metadata(
437        indexed: bool,
438        base_output_index: u32,
439        batch_set_index: Option<NonMaxU32>,
440        indirect_parameters_buffers: &mut UntypedPhaseIndirectParametersBuffers,
441        indirect_parameters_offset: u32,
442    ) {
443        // Note that `IndirectParameters` covers both of these structures, even
444        // though they actually have distinct layouts. See the comment above that
445        // type for more information.
446        let indirect_parameters = IndirectParametersCpuMetadata {
447            base_output_index,
448            batch_set_index: match batch_set_index {
449                None => !0,
450                Some(batch_set_index) => u32::from(batch_set_index),
451            },
452        };
453
454        if indexed {
455            indirect_parameters_buffers
456                .indexed
457                .set(indirect_parameters_offset, indirect_parameters);
458        } else {
459            indirect_parameters_buffers
460                .non_indexed
461                .set(indirect_parameters_offset, indirect_parameters);
462        }
463    }
464
465    fn get_binned_index(
466        _param: &SystemParamItem<Self::Param>,
467        _query_item: MainEntity,
468    ) -> Option<NonMaxU32> {
469        None
470    }
471}
472
473// When defining a phase, we need to extract it from the main world and add it to a resource
474// that will be used by the render world. We need to give that resource all views that will use
475// that phase
476fn extract_camera_phases(
477    mut stencil_phases: ResMut<ViewSortedRenderPhases<Stencil3d>>,
478    cameras: Extract<Query<(Entity, &Camera), With<Camera3d>>>,
479    mut live_entities: Local<HashSet<RetainedViewEntity>>,
480) {
481    live_entities.clear();
482    for (main_entity, camera) in &cameras {
483        if !camera.is_active {
484            continue;
485        }
486        // This is the main camera, so we use the first subview index (0)
487        let retained_view_entity = RetainedViewEntity::new(main_entity.into(), None, 0);
488
489        stencil_phases.insert_or_clear(retained_view_entity);
490        live_entities.insert(retained_view_entity);
491    }
492
493    // Clear out all dead views.
494    stencil_phases.retain(|camera_entity, _| live_entities.contains(camera_entity));
495}
496
497// This is a very important step when writing a custom phase.
498//
499// This system determines which meshes will be added to the phase.
500fn queue_custom_meshes(
501    custom_draw_functions: Res<DrawFunctions<Stencil3d>>,
502    mut pipelines: ResMut<SpecializedMeshPipelines<StencilPipeline>>,
503    pipeline_cache: Res<PipelineCache>,
504    custom_draw_pipeline: Res<StencilPipeline>,
505    render_meshes: Res<RenderAssets<RenderMesh>>,
506    render_mesh_instances: Res<RenderMeshInstances>,
507    mut custom_render_phases: ResMut<ViewSortedRenderPhases<Stencil3d>>,
508    mut views: Query<(&ExtractedView, &RenderVisibleEntities, &Msaa)>,
509    has_marker: Query<(), With<DrawStencil>>,
510) {
511    for (view, visible_entities, msaa) in &mut views {
512        let Some(custom_phase) = custom_render_phases.get_mut(&view.retained_view_entity) else {
513            continue;
514        };
515        let draw_custom = custom_draw_functions.read().id::<DrawMesh3dStencil>();
516
517        // Create the key based on the view.
518        // In this case we only care about MSAA and HDR
519        let view_key = MeshPipelineKey::from_msaa_samples(msaa.samples())
520            | MeshPipelineKey::from_hdr(view.hdr);
521
522        let rangefinder = view.rangefinder3d();
523        // Since our phase can work on any 3d mesh we can reuse the default mesh 3d filter
524        for (render_entity, visible_entity) in visible_entities.iter::<Mesh3d>() {
525            // We only want meshes with the marker component to be queued to our phase.
526            if has_marker.get(*render_entity).is_err() {
527                continue;
528            }
529            let Some(mesh_instance) = render_mesh_instances.render_mesh_queue_data(*visible_entity)
530            else {
531                continue;
532            };
533            let Some(mesh) = render_meshes.get(mesh_instance.mesh_asset_id) else {
534                continue;
535            };
536
537            // Specialize the key for the current mesh entity
538            // For this example we only specialize based on the mesh topology
539            // but you could have more complex keys and that's where you'd need to create those keys
540            let mut mesh_key = view_key;
541            mesh_key |= MeshPipelineKey::from_primitive_topology(mesh.primitive_topology());
542
543            let pipeline_id = pipelines.specialize(
544                &pipeline_cache,
545                &custom_draw_pipeline,
546                mesh_key,
547                &mesh.layout,
548            );
549            let pipeline_id = match pipeline_id {
550                Ok(id) => id,
551                Err(err) => {
552                    error!("{}", err);
553                    continue;
554                }
555            };
556            let distance = rangefinder.distance_translation(&mesh_instance.translation);
557            // At this point we have all the data we need to create a phase item and add it to our
558            // phase
559            custom_phase.add(Stencil3d {
560                // Sort the data based on the distance to the view
561                sort_key: FloatOrd(distance),
562                entity: (*render_entity, *visible_entity),
563                pipeline: pipeline_id,
564                draw_function: draw_custom,
565                // Sorted phase items aren't batched
566                batch_range: 0..1,
567                extra_index: PhaseItemExtraIndex::None,
568                indexed: mesh.indexed(),
569            });
570        }
571    }
572}
573
574// Render label used to order our render graph node that will render our phase
575#[derive(RenderLabel, Debug, Clone, Hash, PartialEq, Eq)]
576struct CustomDrawPassLabel;
577
578#[derive(Default)]
579struct CustomDrawNode;
580impl ViewNode for CustomDrawNode {
581    type ViewQuery = (
582        &'static ExtractedCamera,
583        &'static ExtractedView,
584        &'static ViewTarget,
585    );
586
587    fn run<'w>(
588        &self,
589        graph: &mut RenderGraphContext,
590        render_context: &mut RenderContext<'w>,
591        (camera, view, target): QueryItem<'w, Self::ViewQuery>,
592        world: &'w World,
593    ) -> Result<(), NodeRunError> {
594        // First, we need to get our phases resource
595        let Some(stencil_phases) = world.get_resource::<ViewSortedRenderPhases<Stencil3d>>() else {
596            return Ok(());
597        };
598
599        // Get the view entity from the graph
600        let view_entity = graph.view_entity();
601
602        // Get the phase for the current view running our node
603        let Some(stencil_phase) = stencil_phases.get(&view.retained_view_entity) else {
604            return Ok(());
605        };
606
607        // Render pass setup
608        let mut render_pass = render_context.begin_tracked_render_pass(RenderPassDescriptor {
609            label: Some("stencil pass"),
610            // For the purpose of the example, we will write directly to the view target. A real
611            // stencil pass would write to a custom texture and that texture would be used in later
612            // passes to render custom effects using it.
613            color_attachments: &[Some(target.get_color_attachment())],
614            // We don't bind any depth buffer for this pass
615            depth_stencil_attachment: None,
616            timestamp_writes: None,
617            occlusion_query_set: None,
618        });
619
620        if let Some(viewport) = camera.viewport.as_ref() {
621            render_pass.set_camera_viewport(viewport);
622        }
623
624        // Render the phase
625        // This will execute each draw functions of each phase items queued in this phase
626        if let Err(err) = stencil_phase.render(&mut render_pass, world, view_entity) {
627            error!("Error encountered while rendering the stencil phase {err:?}");
628        }
629
630        Ok(())
631    }
632}