Skip to main content

bevy_ui_render/
ui_material_pipeline.rs

1use crate::ui_material::{MaterialNode, UiMaterial, UiMaterialKey};
2use crate::*;
3use bevy_asset::*;
4use bevy_ecs::{
5    prelude::{Component, With},
6    query::ROQueryItem,
7    system::{
8        lifetimeless::{Read, SRes},
9        *,
10    },
11};
12use bevy_image::BevyDefault as _;
13use bevy_math::{Affine2, FloatOrd, Rect, Vec2};
14use bevy_mesh::VertexBufferLayout;
15use bevy_render::{
16    globals::{GlobalsBuffer, GlobalsUniform},
17    render_asset::{PrepareAssetError, RenderAsset, RenderAssetPlugin, RenderAssets},
18    render_phase::*,
19    render_resource::{binding_types::uniform_buffer, *},
20    renderer::{RenderDevice, RenderQueue},
21    sync_world::{MainEntity, TemporaryRenderEntity},
22    view::*,
23    Extract, ExtractSchedule, Render, RenderSystems,
24};
25use bevy_render::{RenderApp, RenderStartup};
26use bevy_shader::{load_shader_library, Shader, ShaderRef};
27use bevy_sprite::BorderRect;
28use bevy_utils::default;
29use bytemuck::{Pod, Zeroable};
30use core::{hash::Hash, marker::PhantomData, ops::Range};
31
32/// Adds the necessary ECS resources and render logic to enable rendering entities using the given
33/// [`UiMaterial`] asset type (which includes [`UiMaterial`] types).
34pub struct UiMaterialPlugin<M: UiMaterial>(PhantomData<M>);
35
36impl<M: UiMaterial> Default for UiMaterialPlugin<M> {
37    fn default() -> Self {
38        Self(Default::default())
39    }
40}
41
42impl<M: UiMaterial> Plugin for UiMaterialPlugin<M>
43where
44    M::Data: PartialEq + Eq + Hash + Clone,
45{
46    fn build(&self, app: &mut App) {
47        load_shader_library!(app, "ui_vertex_output.wgsl");
48
49        embedded_asset!(app, "ui_material.wgsl");
50
51        app.init_asset::<M>()
52            .register_type::<MaterialNode<M>>()
53            .add_plugins(RenderAssetPlugin::<PreparedUiMaterial<M>>::default());
54
55        if let Some(render_app) = app.get_sub_app_mut(RenderApp) {
56            render_app
57                .add_render_command::<TransparentUi, DrawUiMaterial<M>>()
58                .init_resource::<ExtractedUiMaterialNodes<M>>()
59                .init_resource::<UiMaterialMeta<M>>()
60                .init_resource::<SpecializedRenderPipelines<UiMaterialPipeline<M>>>()
61                .add_systems(RenderStartup, init_ui_material_pipeline::<M>)
62                .add_systems(
63                    ExtractSchedule,
64                    extract_ui_material_nodes::<M>.in_set(RenderUiSystems::ExtractBackgrounds),
65                )
66                .add_systems(
67                    Render,
68                    (
69                        queue_ui_material_nodes::<M>.in_set(RenderSystems::Queue),
70                        prepare_uimaterial_nodes::<M>.in_set(RenderSystems::PrepareBindGroups),
71                    ),
72                );
73        }
74    }
75}
76
77#[derive(Resource)]
78pub struct UiMaterialMeta<M: UiMaterial> {
79    vertices: RawBufferVec<UiMaterialVertex>,
80    view_bind_group: Option<BindGroup>,
81    marker: PhantomData<M>,
82}
83
84impl<M: UiMaterial> Default for UiMaterialMeta<M> {
85    fn default() -> Self {
86        Self {
87            vertices: RawBufferVec::new(BufferUsages::VERTEX),
88            view_bind_group: Default::default(),
89            marker: PhantomData,
90        }
91    }
92}
93
/// One vertex of a UI material quad, as written by `prepare_uimaterial_nodes` into the vertex
/// buffer of [`UiMaterialMeta`].
///
/// `#[repr(C)]` and the field order must stay in sync with the vertex buffer layout built in
/// `UiMaterialPipeline::specialize` (`Float32x3`, `Float32x2`, `Float32x2`, `Float32x4`,
/// `Float32x4`, in that order).
#[repr(C)]
#[derive(Copy, Clone, Pod, Zeroable)]
pub struct UiMaterialVertex {
    /// Transformed and clipped corner position; `z` is always `1.0`.
    pub position: [f32; 3],
    /// Corner UV, adjusted by the clip offsets and divided by the node rect's `max`.
    pub uv: [f32; 2],
    /// Unclipped size of the node's rect (same for all vertices of a node).
    pub size: [f32; 2],
    /// Border insets as `[min_inset.x, min_inset.y, max_inset.x, max_inset.y]`.
    pub border: [f32; 4],
    /// Border radius values from the node's `border_radius()`; corner order TODO confirm.
    pub radius: [f32; 4],
}
103
104// in this [`UiMaterialPipeline`] there is (currently) no batching going on.
105// Therefore the [`UiMaterialBatch`] is more akin to a draw call.
106#[derive(Component)]
107pub struct UiMaterialBatch<M: UiMaterial> {
108    /// The range of vertices inside the [`UiMaterialMeta`]
109    pub range: Range<u32>,
110    pub material: AssetId<M>,
111}
112
113/// Render pipeline data for a given [`UiMaterial`]
114#[derive(Resource)]
115pub struct UiMaterialPipeline<M: UiMaterial> {
116    pub ui_layout: BindGroupLayoutDescriptor,
117    pub view_layout: BindGroupLayoutDescriptor,
118    pub vertex_shader: Handle<Shader>,
119    pub fragment_shader: Handle<Shader>,
120    marker: PhantomData<M>,
121}
122
123impl<M: UiMaterial> SpecializedRenderPipeline for UiMaterialPipeline<M>
124where
125    M::Data: PartialEq + Eq + Hash + Clone,
126{
127    type Key = UiMaterialKey<M>;
128
129    fn specialize(&self, key: Self::Key) -> RenderPipelineDescriptor {
130        let vertex_layout = VertexBufferLayout::from_vertex_formats(
131            VertexStepMode::Vertex,
132            vec![
133                // position
134                VertexFormat::Float32x3,
135                // uv
136                VertexFormat::Float32x2,
137                // size
138                VertexFormat::Float32x2,
139                // border widths
140                VertexFormat::Float32x4,
141                // border radius
142                VertexFormat::Float32x4,
143            ],
144        );
145        let shader_defs = Vec::new();
146
147        let mut descriptor = RenderPipelineDescriptor {
148            vertex: VertexState {
149                shader: self.vertex_shader.clone(),
150                shader_defs: shader_defs.clone(),
151                buffers: vec![vertex_layout],
152                ..default()
153            },
154            fragment: Some(FragmentState {
155                shader: self.fragment_shader.clone(),
156                shader_defs,
157                targets: vec![Some(ColorTargetState {
158                    format: if key.hdr {
159                        ViewTarget::TEXTURE_FORMAT_HDR
160                    } else {
161                        TextureFormat::bevy_default()
162                    },
163                    blend: Some(BlendState::ALPHA_BLENDING),
164                    write_mask: ColorWrites::ALL,
165                })],
166                ..default()
167            }),
168            label: Some("ui_material_pipeline".into()),
169            ..default()
170        };
171
172        descriptor.layout = vec![self.view_layout.clone(), self.ui_layout.clone()];
173
174        M::specialize(&mut descriptor, key);
175
176        descriptor
177    }
178}
179
180pub fn init_ui_material_pipeline<M: UiMaterial>(
181    mut commands: Commands,
182    asset_server: Res<AssetServer>,
183    render_device: Res<RenderDevice>,
184) {
185    let ui_layout = M::bind_group_layout_descriptor(&render_device);
186
187    let view_layout = BindGroupLayoutDescriptor::new(
188        "ui_view_layout",
189        &BindGroupLayoutEntries::sequential(
190            ShaderStages::VERTEX_FRAGMENT,
191            (
192                uniform_buffer::<ViewUniform>(true),
193                uniform_buffer::<GlobalsUniform>(false),
194            ),
195        ),
196    );
197
198    let load_default = || load_embedded_asset!(asset_server.as_ref(), "ui_material.wgsl");
199
200    commands.insert_resource(UiMaterialPipeline::<M> {
201        ui_layout,
202        view_layout,
203        vertex_shader: match M::vertex_shader() {
204            ShaderRef::Default => load_default(),
205            ShaderRef::Handle(handle) => handle,
206            ShaderRef::Path(path) => asset_server.load(path),
207        },
208        fragment_shader: match M::fragment_shader() {
209            ShaderRef::Default => load_default(),
210            ShaderRef::Handle(handle) => handle,
211            ShaderRef::Path(path) => asset_server.load(path),
212        },
213        marker: PhantomData,
214    });
215}
216
/// Draw function registered for [`TransparentUi`] items of material type `M`.
///
/// The tuple's render commands are, in order: set the item's specialized pipeline, bind the
/// view + globals bind group at group 0, bind the material's bind group at group 1, then issue
/// the draw for the item's [`UiMaterialBatch`]. The group indices match the layout order set in
/// `UiMaterialPipeline::specialize`.
pub type DrawUiMaterial<M> = (
    SetItemPipeline,
    SetMatUiViewBindGroup<M, 0>,
    SetUiMaterialBindGroup<M, 1>,
    DrawUiMaterialNode<M>,
);
223
224pub struct SetMatUiViewBindGroup<M: UiMaterial, const I: usize>(PhantomData<M>);
225impl<P: PhaseItem, M: UiMaterial, const I: usize> RenderCommand<P> for SetMatUiViewBindGroup<M, I> {
226    type Param = SRes<UiMaterialMeta<M>>;
227    type ViewQuery = Read<ViewUniformOffset>;
228    type ItemQuery = ();
229
230    fn render<'w>(
231        _item: &P,
232        view_uniform: &'w ViewUniformOffset,
233        _entity: Option<()>,
234        ui_meta: SystemParamItem<'w, '_, Self::Param>,
235        pass: &mut TrackedRenderPass<'w>,
236    ) -> RenderCommandResult {
237        pass.set_bind_group(
238            I,
239            ui_meta.into_inner().view_bind_group.as_ref().unwrap(),
240            &[view_uniform.offset],
241        );
242        RenderCommandResult::Success
243    }
244}
245
246pub struct SetUiMaterialBindGroup<M: UiMaterial, const I: usize>(PhantomData<M>);
247impl<P: PhaseItem, M: UiMaterial, const I: usize> RenderCommand<P>
248    for SetUiMaterialBindGroup<M, I>
249{
250    type Param = SRes<RenderAssets<PreparedUiMaterial<M>>>;
251    type ViewQuery = ();
252    type ItemQuery = Read<UiMaterialBatch<M>>;
253
254    fn render<'w>(
255        _item: &P,
256        _view: (),
257        material_handle: Option<ROQueryItem<'_, '_, Self::ItemQuery>>,
258        materials: SystemParamItem<'w, '_, Self::Param>,
259        pass: &mut TrackedRenderPass<'w>,
260    ) -> RenderCommandResult {
261        let Some(material_handle) = material_handle else {
262            return RenderCommandResult::Skip;
263        };
264        let Some(material) = materials.into_inner().get(material_handle.material) else {
265            return RenderCommandResult::Skip;
266        };
267        pass.set_bind_group(I, &material.bind_group, &[]);
268        RenderCommandResult::Success
269    }
270}
271
272pub struct DrawUiMaterialNode<M>(PhantomData<M>);
273impl<P: PhaseItem, M: UiMaterial> RenderCommand<P> for DrawUiMaterialNode<M> {
274    type Param = SRes<UiMaterialMeta<M>>;
275    type ViewQuery = ();
276    type ItemQuery = Read<UiMaterialBatch<M>>;
277
278    #[inline]
279    fn render<'w>(
280        _item: &P,
281        _view: (),
282        batch: Option<&'w UiMaterialBatch<M>>,
283        ui_meta: SystemParamItem<'w, '_, Self::Param>,
284        pass: &mut TrackedRenderPass<'w>,
285    ) -> RenderCommandResult {
286        let Some(batch) = batch else {
287            return RenderCommandResult::Skip;
288        };
289
290        pass.set_vertex_buffer(0, ui_meta.into_inner().vertices.buffer().unwrap().slice(..));
291        pass.draw(batch.range.clone(), 0..1);
292        RenderCommandResult::Success
293    }
294}
295
296pub struct ExtractedUiMaterialNode<M: UiMaterial> {
297    pub stack_index: u32,
298    pub transform: Affine2,
299    pub rect: Rect,
300    pub border: BorderRect,
301    pub border_radius: [f32; 4],
302    pub material: AssetId<M>,
303    pub clip: Option<Rect>,
304    // Camera to render this UI node to. By the time it is extracted,
305    // it is defaulted to a single camera if only one exists.
306    // Nodes with ambiguous camera will be ignored.
307    pub extracted_camera_entity: Entity,
308    pub main_entity: MainEntity,
309    pub render_entity: Entity,
310}
311
312#[derive(Resource)]
313pub struct ExtractedUiMaterialNodes<M: UiMaterial> {
314    pub uinodes: Vec<ExtractedUiMaterialNode<M>>,
315}
316
317impl<M: UiMaterial> Default for ExtractedUiMaterialNodes<M> {
318    fn default() -> Self {
319        Self {
320            uinodes: Default::default(),
321        }
322    }
323}
324
325pub fn extract_ui_material_nodes<M: UiMaterial>(
326    mut commands: Commands,
327    mut extracted_uinodes: ResMut<ExtractedUiMaterialNodes<M>>,
328    materials: Extract<Res<Assets<M>>>,
329    uinode_query: Extract<
330        Query<(
331            Entity,
332            &ComputedNode,
333            &UiGlobalTransform,
334            &MaterialNode<M>,
335            &InheritedVisibility,
336            Option<&CalculatedClip>,
337            &ComputedUiTargetCamera,
338        )>,
339    >,
340    camera_map: Extract<UiCameraMap>,
341) {
342    let mut camera_mapper = camera_map.get_mapper();
343
344    for (entity, computed_node, transform, handle, inherited_visibility, clip, camera) in
345        uinode_query.iter()
346    {
347        // skip invisible nodes
348        if !inherited_visibility.get() || computed_node.is_empty() {
349            continue;
350        }
351
352        // Skip loading materials
353        if !materials.contains(handle) {
354            continue;
355        }
356
357        let Some(extracted_camera_entity) = camera_mapper.map(camera) else {
358            continue;
359        };
360
361        extracted_uinodes.uinodes.push(ExtractedUiMaterialNode {
362            render_entity: commands.spawn(TemporaryRenderEntity).id(),
363            stack_index: computed_node.stack_index,
364            transform: transform.into(),
365            material: handle.id(),
366            rect: Rect {
367                min: Vec2::ZERO,
368                max: computed_node.size(),
369            },
370            border: computed_node.border(),
371            border_radius: computed_node.border_radius().into(),
372            clip: clip.map(|clip| clip.clip),
373            extracted_camera_entity,
374            main_entity: entity.into(),
375        });
376    }
377}
378
/// Builds vertex data, the view bind group and [`UiMaterialBatch`]es for the extracted nodes of
/// material type `M`.
///
/// Work is only done when both the view uniform buffer and the globals buffer have a binding
/// available. In that case this function:
/// - clears and refills the vertices of [`UiMaterialMeta`];
/// - recreates its view bind group;
/// - merges consecutive phase items that share a material into one batch;
/// - uploads the vertices;
/// - inserts each batch on the render entity of the first item of that batch.
///
/// The extracted node list is cleared unconditionally at the end.
pub fn prepare_uimaterial_nodes<M: UiMaterial>(
    mut commands: Commands,
    render_device: Res<RenderDevice>,
    render_queue: Res<RenderQueue>,
    pipeline_cache: Res<PipelineCache>,
    mut ui_meta: ResMut<UiMaterialMeta<M>>,
    mut extracted_uinodes: ResMut<ExtractedUiMaterialNodes<M>>,
    view_uniforms: Res<ViewUniforms>,
    globals_buffer: Res<GlobalsBuffer>,
    ui_material_pipeline: Res<UiMaterialPipeline<M>>,
    mut phases: ResMut<ViewSortedRenderPhases<TransparentUi>>,
    // Number of batches produced last time; used as a capacity hint.
    mut previous_len: Local<usize>,
) {
    if let (Some(view_binding), Some(globals_binding)) = (
        view_uniforms.uniforms.binding(),
        globals_buffer.buffer.binding(),
    ) {
        let mut batches: Vec<(Entity, UiMaterialBatch<M>)> = Vec::with_capacity(*previous_len);

        ui_meta.vertices.clear();
        ui_meta.view_bind_group = Some(render_device.create_bind_group(
            "ui_material_view_bind_group",
            &pipeline_cache.get_bind_group_layout(&ui_material_pipeline.view_layout),
            &BindGroupEntries::sequential((view_binding, globals_binding)),
        ));
        // Running vertex count across all phases; batch ranges index into the shared buffer.
        let mut index = 0;

        for ui_phase in phases.values_mut() {
            // Phase index of the item that owns the current batch, and that batch's material.
            let mut batch_item_index = 0;
            let mut batch_shader_handle = AssetId::invalid();

            for item_index in 0..ui_phase.items.len() {
                let item = &mut ui_phase.items[item_index];
                // `item.index` is the node's position in `extracted_uinodes` (set when queued);
                // the entity comparison rejects items that were not queued from this node list.
                if let Some(extracted_uinode) = extracted_uinodes
                    .uinodes
                    .get(item.index)
                    .filter(|n| item.entity() == n.render_entity)
                {
                    // Extend the last batch only while the run of same-material items is
                    // unbroken (`batch_shader_handle` is reset below when a run is interrupted).
                    let mut existing_batch = batches
                        .last_mut()
                        .filter(|_| batch_shader_handle == extracted_uinode.material);

                    if existing_batch.is_none() {
                        // Start a new, initially empty batch owned by this item.
                        batch_item_index = item_index;
                        batch_shader_handle = extracted_uinode.material;

                        let new_batch = UiMaterialBatch {
                            range: index..index,
                            material: extracted_uinode.material,
                        };

                        batches.push((item.entity(), new_batch));

                        existing_batch = batches.last_mut();
                    }

                    let uinode_rect = extracted_uinode.rect;

                    let rect_size = uinode_rect.size();

                    // Corner positions after the node transform, with z fixed at 1.0.
                    // NOTE(review): the clip math below treats the corners as
                    // (min,min), (max,min), (max,max), (min,max) — confirm against
                    // QUAD_VERTEX_POSITIONS.
                    let positions = QUAD_VERTEX_POSITIONS.map(|pos| {
                        extracted_uinode
                            .transform
                            .transform_point2(pos * rect_size)
                            .extend(1.0)
                    });

                    // Per-corner offsets that move each corner inside `clip`; zero when the
                    // node has no clip rect or the corner is already inside.
                    let positions_diff = if let Some(clip) = extracted_uinode.clip {
                        [
                            Vec2::new(
                                f32::max(clip.min.x - positions[0].x, 0.),
                                f32::max(clip.min.y - positions[0].y, 0.),
                            ),
                            Vec2::new(
                                f32::min(clip.max.x - positions[1].x, 0.),
                                f32::max(clip.min.y - positions[1].y, 0.),
                            ),
                            Vec2::new(
                                f32::min(clip.max.x - positions[2].x, 0.),
                                f32::min(clip.max.y - positions[2].y, 0.),
                            ),
                            Vec2::new(
                                f32::max(clip.min.x - positions[3].x, 0.),
                                f32::min(clip.max.y - positions[3].y, 0.),
                            ),
                        ]
                    } else {
                        [Vec2::ZERO; 4]
                    };

                    let positions_clipped = [
                        positions[0] + positions_diff[0].extend(0.),
                        positions[1] + positions_diff[1].extend(0.),
                        positions[2] + positions_diff[2].extend(0.),
                        positions[3] + positions_diff[3].extend(0.),
                    ];

                    // Absolute extent of the transformed rect, used by the cull test below.
                    let transformed_rect_size = extracted_uinode
                        .transform
                        .transform_vector2(rect_size)
                        .abs();

                    // Don't try to cull nodes that have a rotation
                    // In a rotation around the Z-axis, this value is 0.0 for an angle of 0.0 or π
                    // In those two cases, the culling check can proceed normally as corners will be on
                    // horizontal / vertical lines
                    // For all other angles, bypass the culling check
                    // This does not properly handle all rotations on all axes
                    if extracted_uinode.transform.x_axis[1] == 0.0 {
                        // Cull nodes that are completely clipped
                        // (a batch started for this item above keeps an empty vertex range).
                        if positions_diff[0].x - positions_diff[1].x >= transformed_rect_size.x
                            || positions_diff[1].y - positions_diff[2].y >= transformed_rect_size.y
                        {
                            continue;
                        }
                    }
                    // UVs shifted by the same clip offsets, then normalized by `rect.max`
                    // (extracted rects start at `Vec2::ZERO`, so this divides by the size).
                    let uvs = [
                        Vec2::new(
                            uinode_rect.min.x + positions_diff[0].x,
                            uinode_rect.min.y + positions_diff[0].y,
                        ),
                        Vec2::new(
                            uinode_rect.max.x + positions_diff[1].x,
                            uinode_rect.min.y + positions_diff[1].y,
                        ),
                        Vec2::new(
                            uinode_rect.max.x + positions_diff[2].x,
                            uinode_rect.max.y + positions_diff[2].y,
                        ),
                        Vec2::new(
                            uinode_rect.min.x + positions_diff[3].x,
                            uinode_rect.max.y + positions_diff[3].y,
                        ),
                    ]
                    .map(|pos| pos / uinode_rect.max);

                    // Non-indexed geometry: one vertex is pushed per entry of QUAD_INDICES.
                    for i in QUAD_INDICES {
                        ui_meta.vertices.push(UiMaterialVertex {
                            position: positions_clipped[i].into(),
                            uv: uvs[i].into(),
                            size: extracted_uinode.rect.size().into(),
                            radius: extracted_uinode.border_radius,
                            border: [
                                extracted_uinode.border.min_inset.x,
                                extracted_uinode.border.min_inset.y,
                                extracted_uinode.border.max_inset.x,
                                extracted_uinode.border.max_inset.y,
                            ],
                        });
                    }

                    index += QUAD_INDICES.len() as u32;
                    // Grow the batch's vertex range and count this item in the batch_range of
                    // the item that owns the batch.
                    existing_batch.unwrap().1.range.end = index;
                    ui_phase.items[batch_item_index].batch_range_mut().end += 1;
                } else {
                    // An item that is not one of this material type's nodes ends the current run.
                    batch_shader_handle = AssetId::invalid();
                }
            }
        }
        ui_meta.vertices.write_buffer(&render_device, &render_queue);
        *previous_len = batches.len();
        commands.try_insert_batch(batches);
    }
    // Extracted nodes are per-frame data; drop them even if nothing was prepared.
    extracted_uinodes.uinodes.clear();
}
544
545pub struct PreparedUiMaterial<T: UiMaterial> {
546    pub bindings: BindingResources,
547    pub bind_group: BindGroup,
548    pub key: T::Data,
549}
550
551impl<M: UiMaterial> RenderAsset for PreparedUiMaterial<M> {
552    type SourceAsset = M;
553
554    type Param = (
555        SRes<RenderDevice>,
556        SRes<PipelineCache>,
557        SRes<UiMaterialPipeline<M>>,
558        M::Param,
559    );
560
561    fn prepare_asset(
562        material: Self::SourceAsset,
563        _: AssetId<Self::SourceAsset>,
564        (render_device, pipeline_cache, pipeline, material_param): &mut SystemParamItem<
565            Self::Param,
566        >,
567        _: Option<&Self>,
568    ) -> Result<Self, PrepareAssetError<Self::SourceAsset>> {
569        let bind_group_data = material.bind_group_data();
570        match material.as_bind_group(
571            &pipeline.ui_layout.clone(),
572            render_device,
573            pipeline_cache,
574            material_param,
575        ) {
576            Ok(prepared) => Ok(PreparedUiMaterial {
577                bindings: prepared.bindings,
578                bind_group: prepared.bind_group,
579                key: bind_group_data,
580            }),
581            Err(AsBindGroupError::RetryNextUpdate) => {
582                Err(PrepareAssetError::RetryNextUpdate(material))
583            }
584            Err(other) => Err(PrepareAssetError::AsBindGroupError(other)),
585        }
586    }
587}
588
589pub fn queue_ui_material_nodes<M: UiMaterial>(
590    extracted_uinodes: Res<ExtractedUiMaterialNodes<M>>,
591    draw_functions: Res<DrawFunctions<TransparentUi>>,
592    ui_material_pipeline: Res<UiMaterialPipeline<M>>,
593    mut pipelines: ResMut<SpecializedRenderPipelines<UiMaterialPipeline<M>>>,
594    pipeline_cache: Res<PipelineCache>,
595    render_materials: Res<RenderAssets<PreparedUiMaterial<M>>>,
596    mut transparent_render_phases: ResMut<ViewSortedRenderPhases<TransparentUi>>,
597    mut render_views: Query<&UiCameraView, With<ExtractedView>>,
598    camera_views: Query<&ExtractedView>,
599) where
600    M::Data: PartialEq + Eq + Hash + Clone,
601{
602    let draw_function = draw_functions.read().id::<DrawUiMaterial<M>>();
603
604    for (index, extracted_uinode) in extracted_uinodes.uinodes.iter().enumerate() {
605        let Some(material) = render_materials.get(extracted_uinode.material) else {
606            continue;
607        };
608
609        let Ok(default_camera_view) =
610            render_views.get_mut(extracted_uinode.extracted_camera_entity)
611        else {
612            continue;
613        };
614
615        let Ok(view) = camera_views.get(default_camera_view.0) else {
616            continue;
617        };
618
619        let Some(transparent_phase) = transparent_render_phases.get_mut(&view.retained_view_entity)
620        else {
621            continue;
622        };
623
624        let pipeline = pipelines.specialize(
625            &pipeline_cache,
626            &ui_material_pipeline,
627            UiMaterialKey {
628                hdr: view.hdr,
629                bind_group_data: material.key.clone(),
630            },
631        );
632        if transparent_phase.items.capacity() < extracted_uinodes.uinodes.len() {
633            transparent_phase.items.reserve_exact(
634                extracted_uinodes.uinodes.len() - transparent_phase.items.capacity(),
635            );
636        }
637        transparent_phase.add(TransparentUi {
638            draw_function,
639            pipeline,
640            entity: (extracted_uinode.render_entity, extracted_uinode.main_entity),
641            sort_key: FloatOrd(extracted_uinode.stack_index as f32 + M::stack_z_offset()),
642            batch_range: 0..0,
643            extra_index: PhaseItemExtraIndex::None,
644            index,
645            indexed: false,
646        });
647    }
648}