custom_render_phase/
custom_render_phase.rs

//! This example demonstrates how to write a custom render phase.
//!
//! Render phases in bevy are used whenever you need to draw a group of meshes in a specific way.
//! For example, bevy's main pass has an opaque phase and a transparent phase, for both 2d and 3d.
//! Sometimes you may want to draw only a subset of meshes before or after a built-in phase. In
//! those situations you need to write your own phase.
7//!
//! This example showcases what a custom phase that draws a stencil of a bevy mesh could look
//! like. Some shortcuts have been taken for simplicity.
10//!
11//! This example was made for 3d, but a 2d equivalent would be almost identical.
12
13use std::ops::Range;
14
15use bevy::camera::Viewport;
16use bevy::pbr::SetMeshViewEmptyBindGroup;
17use bevy::{
18    camera::MainPassResolutionOverride,
19    core_pipeline::core_3d::graph::{Core3d, Node3d},
20    ecs::{
21        query::QueryItem,
22        system::{lifetimeless::SRes, SystemParamItem},
23    },
24    math::FloatOrd,
25    mesh::MeshVertexBufferLayoutRef,
26    pbr::{
27        DrawMesh, MeshInputUniform, MeshPipeline, MeshPipelineKey, MeshPipelineViewLayoutKey,
28        MeshUniform, RenderMeshInstances, SetMeshBindGroup, SetMeshViewBindGroup,
29    },
30    platform::collections::HashSet,
31    prelude::*,
32    render::{
33        batching::{
34            gpu_preprocessing::{
35                batch_and_prepare_sorted_render_phase, IndirectParametersCpuMetadata,
36                UntypedPhaseIndirectParametersBuffers,
37            },
38            GetBatchData, GetFullBatchData,
39        },
40        camera::ExtractedCamera,
41        extract_component::{ExtractComponent, ExtractComponentPlugin},
42        mesh::{allocator::MeshAllocator, RenderMesh},
43        render_asset::RenderAssets,
44        render_graph::{
45            NodeRunError, RenderGraphContext, RenderGraphExt, RenderLabel, ViewNode, ViewNodeRunner,
46        },
47        render_phase::{
48            sort_phase_system, AddRenderCommand, CachedRenderPipelinePhaseItem, DrawFunctionId,
49            DrawFunctions, PhaseItem, PhaseItemExtraIndex, SetItemPipeline, SortedPhaseItem,
50            SortedRenderPhasePlugin, ViewSortedRenderPhases,
51        },
52        render_resource::{
53            CachedRenderPipelineId, ColorTargetState, ColorWrites, Face, FragmentState,
54            PipelineCache, PrimitiveState, RenderPassDescriptor, RenderPipelineDescriptor,
55            SpecializedMeshPipeline, SpecializedMeshPipelineError, SpecializedMeshPipelines,
56            TextureFormat, VertexState,
57        },
58        renderer::RenderContext,
59        sync_world::MainEntity,
60        view::{ExtractedView, RenderVisibleEntities, RetainedViewEntity, ViewTarget},
61        Extract, Render, RenderApp, RenderDebugFlags, RenderStartup, RenderSystems,
62    },
63};
64use nonmax::NonMaxU32;
65
66const SHADER_ASSET_PATH: &str = "shaders/custom_stencil.wgsl";
67
68fn main() {
69    App::new()
70        .add_plugins((DefaultPlugins, MeshStencilPhasePlugin))
71        .add_systems(Startup, setup)
72        .run();
73}
74
75fn setup(
76    mut commands: Commands,
77    mut meshes: ResMut<Assets<Mesh>>,
78    mut materials: ResMut<Assets<StandardMaterial>>,
79) {
80    // circular base
81    commands.spawn((
82        Mesh3d(meshes.add(Circle::new(4.0))),
83        MeshMaterial3d(materials.add(Color::WHITE)),
84        Transform::from_rotation(Quat::from_rotation_x(-std::f32::consts::FRAC_PI_2)),
85    ));
86    // cube
87    // This cube will be rendered by the main pass, but it will also be rendered by our custom
88    // pass. This should result in an unlit red cube
89    commands.spawn((
90        Mesh3d(meshes.add(Cuboid::new(1.0, 1.0, 1.0))),
91        MeshMaterial3d(materials.add(Color::srgb_u8(124, 144, 255))),
92        Transform::from_xyz(0.0, 0.5, 0.0),
93        // This marker component is used to identify which mesh will be used in our custom pass
94        // The circle doesn't have it so it won't be rendered in our pass
95        DrawStencil,
96    ));
97    // light
98    commands.spawn((
99        PointLight {
100            shadows_enabled: true,
101            ..default()
102        },
103        Transform::from_xyz(4.0, 8.0, 4.0),
104    ));
105    // camera
106    commands.spawn((
107        Camera3d::default(),
108        Transform::from_xyz(-2.0, 4.5, 9.0).looking_at(Vec3::ZERO, Vec3::Y),
109        // disable msaa for simplicity
110        Msaa::Off,
111    ));
112}
113
114#[derive(Component, ExtractComponent, Clone, Copy, Default)]
115struct DrawStencil;
116
117struct MeshStencilPhasePlugin;
118impl Plugin for MeshStencilPhasePlugin {
119    fn build(&self, app: &mut App) {
120        app.add_plugins((
121            ExtractComponentPlugin::<DrawStencil>::default(),
122            SortedRenderPhasePlugin::<Stencil3d, MeshPipeline>::new(RenderDebugFlags::default()),
123        ));
124        // We need to get the render app from the main app
125        let Some(render_app) = app.get_sub_app_mut(RenderApp) else {
126            return;
127        };
128        render_app
129            .init_resource::<SpecializedMeshPipelines<StencilPipeline>>()
130            .init_resource::<DrawFunctions<Stencil3d>>()
131            .add_render_command::<Stencil3d, DrawMesh3dStencil>()
132            .init_resource::<ViewSortedRenderPhases<Stencil3d>>()
133            .add_systems(RenderStartup, init_stencil_pipeline)
134            .add_systems(ExtractSchedule, extract_camera_phases)
135            .add_systems(
136                Render,
137                (
138                    queue_custom_meshes.in_set(RenderSystems::QueueMeshes),
139                    sort_phase_system::<Stencil3d>.in_set(RenderSystems::PhaseSort),
140                    batch_and_prepare_sorted_render_phase::<Stencil3d, StencilPipeline>
141                        .in_set(RenderSystems::PrepareResources),
142                ),
143            );
144
145        render_app
146            .add_render_graph_node::<ViewNodeRunner<CustomDrawNode>>(Core3d, CustomDrawPassLabel)
147            // Tell the node to run after the main pass
148            .add_render_graph_edges(Core3d, (Node3d::MainOpaquePass, CustomDrawPassLabel));
149    }
150}
151
152#[derive(Resource)]
153struct StencilPipeline {
154    /// The base mesh pipeline defined by bevy
155    ///
156    /// Since we want to draw a stencil of an existing bevy mesh we want to reuse the default
157    /// pipeline as much as possible
158    mesh_pipeline: MeshPipeline,
159    /// Stores the shader used for this pipeline directly on the pipeline.
160    /// This isn't required, it's only done like this for simplicity.
161    shader_handle: Handle<Shader>,
162}
163
164fn init_stencil_pipeline(
165    mut commands: Commands,
166    mesh_pipeline: Res<MeshPipeline>,
167    asset_server: Res<AssetServer>,
168) {
169    commands.insert_resource(StencilPipeline {
170        mesh_pipeline: mesh_pipeline.clone(),
171        shader_handle: asset_server.load(SHADER_ASSET_PATH),
172    });
173}
174
175// For more information on how SpecializedMeshPipeline work, please look at the
176// specialized_mesh_pipeline example
177impl SpecializedMeshPipeline for StencilPipeline {
178    type Key = MeshPipelineKey;
179
180    fn specialize(
181        &self,
182        key: Self::Key,
183        layout: &MeshVertexBufferLayoutRef,
184    ) -> Result<RenderPipelineDescriptor, SpecializedMeshPipelineError> {
185        // We will only use the position of the mesh in our shader so we only need to specify that
186        let mut vertex_attributes = Vec::new();
187        if layout.0.contains(Mesh::ATTRIBUTE_POSITION) {
188            // Make sure this matches the shader location
189            vertex_attributes.push(Mesh::ATTRIBUTE_POSITION.at_shader_location(0));
190        }
191        // This will automatically generate the correct `VertexBufferLayout` based on the vertex attributes
192        let vertex_buffer_layout = layout.0.get_layout(&vertex_attributes)?;
193        let view_layout = self
194            .mesh_pipeline
195            .get_view_layout(MeshPipelineViewLayoutKey::from(key));
196        Ok(RenderPipelineDescriptor {
197            label: Some("Specialized Mesh Pipeline".into()),
198            // We want to reuse the data from bevy so we use the same bind groups as the default
199            // mesh pipeline
200            layout: vec![
201                // Bind group 0 is the view uniform
202                view_layout.main_layout.clone(),
203                // Bind group 1 is empty
204                view_layout.empty_layout.clone(),
205                // Bind group 2 is the mesh uniform
206                self.mesh_pipeline.mesh_layouts.model_only.clone(),
207            ],
208            vertex: VertexState {
209                shader: self.shader_handle.clone(),
210                buffers: vec![vertex_buffer_layout],
211                ..default()
212            },
213            fragment: Some(FragmentState {
214                shader: self.shader_handle.clone(),
215                targets: vec![Some(ColorTargetState {
216                    format: TextureFormat::bevy_default(),
217                    blend: None,
218                    write_mask: ColorWrites::ALL,
219                })],
220                ..default()
221            }),
222            primitive: PrimitiveState {
223                topology: key.primitive_topology(),
224                cull_mode: Some(Face::Back),
225                ..default()
226            },
227            // It's generally recommended to specialize your pipeline for MSAA,
228            // but it's not always possible
229            ..default()
230        })
231    }
232}
233
234// We will reuse render commands already defined by bevy to draw a 3d mesh
235type DrawMesh3dStencil = (
236    SetItemPipeline,
237    // This will set the view bindings in group 0
238    SetMeshViewBindGroup<0>,
239    // This will set an empty bind group in group 1
240    SetMeshViewEmptyBindGroup<1>,
241    // This will set the mesh bindings in group 2
242    SetMeshBindGroup<2>,
243    // This will draw the mesh
244    DrawMesh,
245);
246
247// This is the data required per entity drawn in a custom phase in bevy. More specifically this is the
248// data required when using a ViewSortedRenderPhase. This would look differently if we wanted a
249// batched render phase. Sorted phases are a bit easier to implement, but a batched phase would
250// look similar.
251//
252// If you want to see how a batched phase implementation looks, you should look at the Opaque2d
253// phase.
254struct Stencil3d {
255    pub sort_key: FloatOrd,
256    pub entity: (Entity, MainEntity),
257    pub pipeline: CachedRenderPipelineId,
258    pub draw_function: DrawFunctionId,
259    pub batch_range: Range<u32>,
260    pub extra_index: PhaseItemExtraIndex,
261    /// Whether the mesh in question is indexed (uses an index buffer in
262    /// addition to its vertex buffer).
263    pub indexed: bool,
264}
265
266// For more information about writing a phase item, please look at the custom_phase_item example
267impl PhaseItem for Stencil3d {
268    #[inline]
269    fn entity(&self) -> Entity {
270        self.entity.0
271    }
272
273    #[inline]
274    fn main_entity(&self) -> MainEntity {
275        self.entity.1
276    }
277
278    #[inline]
279    fn draw_function(&self) -> DrawFunctionId {
280        self.draw_function
281    }
282
283    #[inline]
284    fn batch_range(&self) -> &Range<u32> {
285        &self.batch_range
286    }
287
288    #[inline]
289    fn batch_range_mut(&mut self) -> &mut Range<u32> {
290        &mut self.batch_range
291    }
292
293    #[inline]
294    fn extra_index(&self) -> PhaseItemExtraIndex {
295        self.extra_index.clone()
296    }
297
298    #[inline]
299    fn batch_range_and_extra_index_mut(&mut self) -> (&mut Range<u32>, &mut PhaseItemExtraIndex) {
300        (&mut self.batch_range, &mut self.extra_index)
301    }
302}
303
304impl SortedPhaseItem for Stencil3d {
305    type SortKey = FloatOrd;
306
307    #[inline]
308    fn sort_key(&self) -> Self::SortKey {
309        self.sort_key
310    }
311
312    #[inline]
313    fn sort(items: &mut [Self]) {
314        // bevy normally uses radsort instead of the std slice::sort_by_key
315        // radsort is a stable radix sort that performed better than `slice::sort_by_key` or `slice::sort_unstable_by_key`.
316        // Since it is not re-exported by bevy, we just use the std sort for the purpose of the example
317        items.sort_by_key(SortedPhaseItem::sort_key);
318    }
319
320    #[inline]
321    fn indexed(&self) -> bool {
322        self.indexed
323    }
324}
325
326impl CachedRenderPipelinePhaseItem for Stencil3d {
327    #[inline]
328    fn cached_pipeline(&self) -> CachedRenderPipelineId {
329        self.pipeline
330    }
331}
332
333impl GetBatchData for StencilPipeline {
334    type Param = (
335        SRes<RenderMeshInstances>,
336        SRes<RenderAssets<RenderMesh>>,
337        SRes<MeshAllocator>,
338    );
339    type CompareData = AssetId<Mesh>;
340    type BufferData = MeshUniform;
341
342    fn get_batch_data(
343        (mesh_instances, _render_assets, mesh_allocator): &SystemParamItem<Self::Param>,
344        (_entity, main_entity): (Entity, MainEntity),
345    ) -> Option<(Self::BufferData, Option<Self::CompareData>)> {
346        let RenderMeshInstances::CpuBuilding(ref mesh_instances) = **mesh_instances else {
347            error!(
348                "`get_batch_data` should never be called in GPU mesh uniform \
349                building mode"
350            );
351            return None;
352        };
353        let mesh_instance = mesh_instances.get(&main_entity)?;
354        let first_vertex_index =
355            match mesh_allocator.mesh_vertex_slice(&mesh_instance.mesh_asset_id) {
356                Some(mesh_vertex_slice) => mesh_vertex_slice.range.start,
357                None => 0,
358            };
359        let mesh_uniform = {
360            let mesh_transforms = &mesh_instance.transforms;
361            let (local_from_world_transpose_a, local_from_world_transpose_b) =
362                mesh_transforms.world_from_local.inverse_transpose_3x3();
363            MeshUniform {
364                world_from_local: mesh_transforms.world_from_local.to_transpose(),
365                previous_world_from_local: mesh_transforms.previous_world_from_local.to_transpose(),
366                lightmap_uv_rect: UVec2::ZERO,
367                local_from_world_transpose_a,
368                local_from_world_transpose_b,
369                flags: mesh_transforms.flags,
370                first_vertex_index,
371                current_skin_index: u32::MAX,
372                material_and_lightmap_bind_group_slot: 0,
373                tag: 0,
374                pad: 0,
375            }
376        };
377        Some((mesh_uniform, None))
378    }
379}
380
381impl GetFullBatchData for StencilPipeline {
382    type BufferInputData = MeshInputUniform;
383
384    fn get_index_and_compare_data(
385        (mesh_instances, _, _): &SystemParamItem<Self::Param>,
386        main_entity: MainEntity,
387    ) -> Option<(NonMaxU32, Option<Self::CompareData>)> {
388        // This should only be called during GPU building.
389        let RenderMeshInstances::GpuBuilding(ref mesh_instances) = **mesh_instances else {
390            error!(
391                "`get_index_and_compare_data` should never be called in CPU mesh uniform building \
392                mode"
393            );
394            return None;
395        };
396        let mesh_instance = mesh_instances.get(&main_entity)?;
397        Some((
398            mesh_instance.current_uniform_index,
399            mesh_instance
400                .should_batch()
401                .then_some(mesh_instance.mesh_asset_id),
402        ))
403    }
404
405    fn get_binned_batch_data(
406        (mesh_instances, _render_assets, mesh_allocator): &SystemParamItem<Self::Param>,
407        main_entity: MainEntity,
408    ) -> Option<Self::BufferData> {
409        let RenderMeshInstances::CpuBuilding(ref mesh_instances) = **mesh_instances else {
410            error!(
411                "`get_binned_batch_data` should never be called in GPU mesh uniform building mode"
412            );
413            return None;
414        };
415        let mesh_instance = mesh_instances.get(&main_entity)?;
416        let first_vertex_index =
417            match mesh_allocator.mesh_vertex_slice(&mesh_instance.mesh_asset_id) {
418                Some(mesh_vertex_slice) => mesh_vertex_slice.range.start,
419                None => 0,
420            };
421
422        Some(MeshUniform::new(
423            &mesh_instance.transforms,
424            first_vertex_index,
425            mesh_instance.material_bindings_index.slot,
426            None,
427            None,
428            None,
429        ))
430    }
431
432    fn write_batch_indirect_parameters_metadata(
433        indexed: bool,
434        base_output_index: u32,
435        batch_set_index: Option<NonMaxU32>,
436        indirect_parameters_buffers: &mut UntypedPhaseIndirectParametersBuffers,
437        indirect_parameters_offset: u32,
438    ) {
439        // Note that `IndirectParameters` covers both of these structures, even
440        // though they actually have distinct layouts. See the comment above that
441        // type for more information.
442        let indirect_parameters = IndirectParametersCpuMetadata {
443            base_output_index,
444            batch_set_index: match batch_set_index {
445                None => !0,
446                Some(batch_set_index) => u32::from(batch_set_index),
447            },
448        };
449
450        if indexed {
451            indirect_parameters_buffers
452                .indexed
453                .set(indirect_parameters_offset, indirect_parameters);
454        } else {
455            indirect_parameters_buffers
456                .non_indexed
457                .set(indirect_parameters_offset, indirect_parameters);
458        }
459    }
460
461    fn get_binned_index(
462        _param: &SystemParamItem<Self::Param>,
463        _query_item: MainEntity,
464    ) -> Option<NonMaxU32> {
465        None
466    }
467}
468
469// When defining a phase, we need to extract it from the main world and add it to a resource
470// that will be used by the render world. We need to give that resource all views that will use
471// that phase
472fn extract_camera_phases(
473    mut stencil_phases: ResMut<ViewSortedRenderPhases<Stencil3d>>,
474    cameras: Extract<Query<(Entity, &Camera), With<Camera3d>>>,
475    mut live_entities: Local<HashSet<RetainedViewEntity>>,
476) {
477    live_entities.clear();
478    for (main_entity, camera) in &cameras {
479        if !camera.is_active {
480            continue;
481        }
482        // This is the main camera, so we use the first subview index (0)
483        let retained_view_entity = RetainedViewEntity::new(main_entity.into(), None, 0);
484
485        stencil_phases.insert_or_clear(retained_view_entity);
486        live_entities.insert(retained_view_entity);
487    }
488
489    // Clear out all dead views.
490    stencil_phases.retain(|camera_entity, _| live_entities.contains(camera_entity));
491}
492
493// This is a very important step when writing a custom phase.
494//
495// This system determines which meshes will be added to the phase.
496fn queue_custom_meshes(
497    custom_draw_functions: Res<DrawFunctions<Stencil3d>>,
498    mut pipelines: ResMut<SpecializedMeshPipelines<StencilPipeline>>,
499    pipeline_cache: Res<PipelineCache>,
500    custom_draw_pipeline: Res<StencilPipeline>,
501    render_meshes: Res<RenderAssets<RenderMesh>>,
502    render_mesh_instances: Res<RenderMeshInstances>,
503    mut custom_render_phases: ResMut<ViewSortedRenderPhases<Stencil3d>>,
504    mut views: Query<(&ExtractedView, &RenderVisibleEntities, &Msaa)>,
505    has_marker: Query<(), With<DrawStencil>>,
506) {
507    for (view, visible_entities, msaa) in &mut views {
508        let Some(custom_phase) = custom_render_phases.get_mut(&view.retained_view_entity) else {
509            continue;
510        };
511        let draw_custom = custom_draw_functions.read().id::<DrawMesh3dStencil>();
512
513        // Create the key based on the view.
514        // In this case we only care about MSAA and HDR
515        let view_key = MeshPipelineKey::from_msaa_samples(msaa.samples())
516            | MeshPipelineKey::from_hdr(view.hdr);
517
518        let rangefinder = view.rangefinder3d();
519        // Since our phase can work on any 3d mesh we can reuse the default mesh 3d filter
520        for (render_entity, visible_entity) in visible_entities.iter::<Mesh3d>() {
521            // We only want meshes with the marker component to be queued to our phase.
522            if has_marker.get(*render_entity).is_err() {
523                continue;
524            }
525            let Some(mesh_instance) = render_mesh_instances.render_mesh_queue_data(*visible_entity)
526            else {
527                continue;
528            };
529            let Some(mesh) = render_meshes.get(mesh_instance.mesh_asset_id) else {
530                continue;
531            };
532
533            // Specialize the key for the current mesh entity
534            // For this example we only specialize based on the mesh topology
535            // but you could have more complex keys and that's where you'd need to create those keys
536            let mut mesh_key = view_key;
537            mesh_key |= MeshPipelineKey::from_primitive_topology(mesh.primitive_topology());
538
539            let pipeline_id = pipelines.specialize(
540                &pipeline_cache,
541                &custom_draw_pipeline,
542                mesh_key,
543                &mesh.layout,
544            );
545            let pipeline_id = match pipeline_id {
546                Ok(id) => id,
547                Err(err) => {
548                    error!("{}", err);
549                    continue;
550                }
551            };
552            let distance = rangefinder.distance_translation(&mesh_instance.translation);
553            // At this point we have all the data we need to create a phase item and add it to our
554            // phase
555            custom_phase.add(Stencil3d {
556                // Sort the data based on the distance to the view
557                sort_key: FloatOrd(distance),
558                entity: (*render_entity, *visible_entity),
559                pipeline: pipeline_id,
560                draw_function: draw_custom,
561                // Sorted phase items aren't batched
562                batch_range: 0..1,
563                extra_index: PhaseItemExtraIndex::None,
564                indexed: mesh.indexed(),
565            });
566        }
567    }
568}
569
570// Render label used to order our render graph node that will render our phase
571#[derive(RenderLabel, Debug, Clone, Hash, PartialEq, Eq)]
572struct CustomDrawPassLabel;
573
574#[derive(Default)]
575struct CustomDrawNode;
576impl ViewNode for CustomDrawNode {
577    type ViewQuery = (
578        &'static ExtractedCamera,
579        &'static ExtractedView,
580        &'static ViewTarget,
581        Option<&'static MainPassResolutionOverride>,
582    );
583
584    fn run<'w>(
585        &self,
586        graph: &mut RenderGraphContext,
587        render_context: &mut RenderContext<'w>,
588        (camera, view, target, resolution_override): QueryItem<'w, '_, Self::ViewQuery>,
589        world: &'w World,
590    ) -> Result<(), NodeRunError> {
591        // First, we need to get our phases resource
592        let Some(stencil_phases) = world.get_resource::<ViewSortedRenderPhases<Stencil3d>>() else {
593            return Ok(());
594        };
595
596        // Get the view entity from the graph
597        let view_entity = graph.view_entity();
598
599        // Get the phase for the current view running our node
600        let Some(stencil_phase) = stencil_phases.get(&view.retained_view_entity) else {
601            return Ok(());
602        };
603
604        // Render pass setup
605        let mut render_pass = render_context.begin_tracked_render_pass(RenderPassDescriptor {
606            label: Some("stencil pass"),
607            // For the purpose of the example, we will write directly to the view target. A real
608            // stencil pass would write to a custom texture and that texture would be used in later
609            // passes to render custom effects using it.
610            color_attachments: &[Some(target.get_color_attachment())],
611            // We don't bind any depth buffer for this pass
612            depth_stencil_attachment: None,
613            timestamp_writes: None,
614            occlusion_query_set: None,
615        });
616
617        if let Some(viewport) =
618            Viewport::from_viewport_and_override(camera.viewport.as_ref(), resolution_override)
619        {
620            render_pass.set_camera_viewport(&viewport);
621        }
622
623        // Render the phase
624        // This will execute each draw functions of each phase items queued in this phase
625        if let Err(err) = stencil_phase.render(&mut render_pass, world, view_entity) {
626            error!("Error encountered while rendering the stencil phase {err:?}");
627        }
628
629        Ok(())
630    }
631}