1use std::{iter::FilterMap, ops::Index};
4
5use bevy_ecs::{
6 entity::EntityHashMap,
7 prelude::*,
8 query::QueryManyIter,
9 system::{
10 SystemParam, SystemState,
11 lifetimeless::{Read, Write},
12 },
13};
14use bevy_hierarchy::prelude::*;
15use bevy_math::prelude::*;
16use bevy_reflect::prelude::*;
17use bevy_transform::prelude::*;
18use taffy::{
19 AvailableSpace, Cache, CacheTree, Layout, LayoutBlockContainer, LayoutFlexboxContainer, LayoutInput, LayoutOutput,
20 LayoutPartialTree, NodeId, PrintTree, RoundTree, RunMode, Size, TraversePartialTree, TraverseTree, compute_block_layout,
21 compute_cached_layout, compute_flexbox_layout, compute_hidden_layout, compute_leaf_layout, compute_root_layout,
22 round_layout,
23};
24
25use crate::{
26 measure::{ContentSize, MeasureId, Measurements, Measurer},
27 root::{UiRootTrns, UiUnrounded},
28 style::{Display, Ui, WithCtx},
29};
30
31#[derive(Component, Copy, Clone, Default)]
33#[require(IntermediateUi, UiCache)]
34pub struct ComputedUi {
35 pub order: u32,
40 location: Vec2,
42 pub size: Vec2,
44 pub content_size: Vec2,
48 pub scrollbar_size: Vec2,
51 pub border: Border,
53 pub padding: Border,
55 pub margin: Border,
57}
58
59#[derive(Component, Copy, Clone, Default)]
60pub(crate) struct IntermediateUi(Layout);
61
62#[derive(Reflect, Copy, Clone, Default)]
64#[reflect(Default)]
65pub struct Border {
66 pub left: f32,
68 pub right: f32,
70 pub top: f32,
72 pub bottom: f32,
74}
75
76impl From<taffy::Rect<f32>> for Border {
77 #[inline]
78 fn from(
79 taffy::Rect {
80 left,
81 right,
82 top,
83 bottom,
84 }: taffy::Rect<f32>,
85 ) -> Self {
86 Self {
87 left,
88 right,
89 top,
90 bottom,
91 }
92 }
93}
94
95#[derive(Component, Default)]
96pub(crate) struct UiCache(Cache);
97impl UiCache {
98 #[inline]
99 pub fn clear(&mut self) {
100 self.0.clear()
101 }
102}
103
104#[derive(SystemParam)]
107pub struct UiCaches<'w, 's>(Query<'w, 's, (Write<UiCache>, Option<Read<Parent>>)>);
108
109impl UiCaches<'_, '_> {
110 #[inline]
112 pub fn invalidate(&mut self, mut e: Entity) {
113 loop {
114 let Ok((mut cache, parent)) = self.0.get_mut(e) else { break };
115 cache.clear();
116
117 if let Some(parent) = parent { e = **parent } else { break }
118 }
119 }
120}
121
122pub(crate) struct UiTree<'w, 's, M> {
123 measurements: M,
124 viewport_size: Vec2,
125 ui_query: Query<'w, 's, (Entity, Has<UiRootTrns>, Read<Ui>, Option<Read<ContentSize>>)>,
126 children_query: Query<'w, 's, Read<Children>>,
127 intermediate_query: Query<'w, 's, Write<IntermediateUi>>,
128 cache_query: Query<'w, 's, Write<UiCache>>,
129 outputs: &'s mut EntityHashMap<Layout>,
130}
131
132impl<M> TraverseTree for UiTree<'_, '_, M> {}
133
134impl<M> TraversePartialTree for UiTree<'_, '_, M> {
135 type ChildIter<'a>
136 = FilterMap<
137 QueryManyIter<
138 'a,
139 'a,
140 (Entity, Has<UiRootTrns>, Read<Ui>, Option<Read<ContentSize>>),
141 (),
142 std::slice::Iter<'a, Entity>,
143 >,
144 fn((Entity, bool, &Ui, Option<&ContentSize>)) -> Option<NodeId>,
145 >
146 where
147 Self: 'a;
148
149 #[inline]
150 fn child_ids(&self, parent_node_id: NodeId) -> Self::ChildIter<'_> {
151 let children = self
152 .children_query
153 .get(Entity::from_bits(parent_node_id.into()))
154 .map(|children| &**children)
155 .unwrap_or(&[])
156 .iter();
157
158 self.ui_query
159 .iter_many(children)
160 .filter_map(|(e, is_root, ..)| (!is_root).then_some(NodeId::from(e.to_bits())))
161 }
162
163 #[inline]
164 fn child_count(&self, parent_node_id: NodeId) -> usize {
165 self.child_ids(parent_node_id).count()
166 }
167
168 #[inline]
169 fn get_child_id(&self, parent_node_id: NodeId, child_index: usize) -> NodeId {
170 self.child_ids(parent_node_id).nth(child_index).unwrap()
171 }
172}
173
174impl<M: Index<MeasureId, Output = dyn Measurer>> LayoutPartialTree for UiTree<'_, '_, M> {
175 type CoreContainerStyle<'a>
176 = WithCtx<&'a Ui>
177 where
178 Self: 'a;
179
180 #[inline]
181 fn get_core_container_style(&self, node_id: NodeId) -> Self::CoreContainerStyle<'_> {
182 let e = Entity::from_bits(node_id.into());
183 WithCtx {
184 width: self.viewport_size.x,
185 height: self.viewport_size.y,
186 item: self.ui_query.get(e).unwrap().2,
187 }
188 }
189
190 #[inline]
191 fn set_unrounded_layout(&mut self, node_id: NodeId, layout: &Layout) {
192 self.intermediate_query.get_mut(Entity::from_bits(node_id.into())).unwrap().0 = *layout
193 }
194
195 #[inline]
196 fn compute_child_layout(&mut self, node_id: NodeId, inputs: LayoutInput) -> LayoutOutput {
197 compute_cached_layout(self, node_id, inputs, |tree, node_id, inputs| {
198 let e = Entity::from_bits(node_id.into());
199 let (.., node, measure) = tree.ui_query.get(e).unwrap();
200 let has_children = tree.child_count(node_id) != 0;
201
202 match (node.display, has_children) {
203 (Display::Flexbox, true) => compute_flexbox_layout(tree, node_id, inputs),
204 (Display::Block, true) => compute_block_layout(tree, node_id, inputs),
205 (Display::None, _) => compute_hidden_layout(tree, node_id),
206 (_, false) => compute_leaf_layout(
207 inputs,
208 &WithCtx {
209 width: tree.viewport_size.x,
210 height: tree.viewport_size.y,
211 item: node,
212 },
213 |known_size, available_space| {
214 if let Some(measure) = measure.and_then(|id| match id.get() {
215 MeasureId::INVALID => None,
216 id => Some(id),
217 }) {
218 let Vec2 { x: width, y: height } = tree.measurements[measure].measure(
219 (known_size.width, known_size.height),
220 (available_space.width.into(), available_space.height.into()),
221 e,
222 );
223
224 Size { width, height }
225 } else {
226 Size::ZERO
227 }
228 },
229 ),
230 }
231 })
232 }
233}
234
235impl<M> RoundTree for UiTree<'_, '_, M> {
236 #[inline]
237 fn get_unrounded_layout(&self, node_id: NodeId) -> &Layout {
238 &self.intermediate_query.get(Entity::from_bits(node_id.into())).unwrap().0
239 }
240
241 #[inline]
242 fn set_final_layout(&mut self, node_id: NodeId, layout: &Layout) {
243 let e = Entity::from_bits(node_id.into());
244 self.outputs.insert(e, *layout);
245 }
246}
247
248impl<M> PrintTree for UiTree<'_, '_, M> {
249 #[inline]
250 fn get_debug_label(&self, node_id: NodeId) -> &'static str {
251 let node = self.ui_query.get(Entity::from_bits(node_id.into())).unwrap().2;
252 match node.display {
253 Display::Flexbox => "flexbox",
254 Display::Block => "block",
255 Display::None => "none",
256 }
257 }
258
259 #[inline]
260 fn get_final_layout(&self, node_id: NodeId) -> &Layout {
261 &self.outputs[&Entity::from_bits(node_id.into())]
262 }
263}
264
265impl<M: Index<MeasureId, Output = dyn Measurer>> LayoutFlexboxContainer for UiTree<'_, '_, M> {
266 type FlexboxContainerStyle<'a>
267 = WithCtx<&'a Ui>
268 where
269 Self: 'a;
270
271 type FlexboxItemStyle<'a>
272 = WithCtx<&'a Ui>
273 where
274 Self: 'a;
275
276 #[inline]
277 fn get_flexbox_container_style(&self, node_id: NodeId) -> Self::FlexboxContainerStyle<'_> {
278 self.get_core_container_style(node_id)
279 }
280
281 #[inline]
282 fn get_flexbox_child_style(&self, child_node_id: NodeId) -> Self::FlexboxItemStyle<'_> {
283 self.get_core_container_style(child_node_id)
284 }
285}
286
287impl<M: Index<MeasureId, Output = dyn Measurer>> LayoutBlockContainer for UiTree<'_, '_, M> {
288 type BlockContainerStyle<'a>
289 = WithCtx<&'a Ui>
290 where
291 Self: 'a;
292
293 type BlockItemStyle<'a>
294 = WithCtx<&'a Ui>
295 where
296 Self: 'a;
297
298 #[inline]
299 fn get_block_container_style(&self, node_id: NodeId) -> Self::BlockContainerStyle<'_> {
300 self.get_core_container_style(node_id)
301 }
302
303 #[inline]
304 fn get_block_child_style(&self, child_node_id: NodeId) -> Self::BlockItemStyle<'_> {
305 self.get_core_container_style(child_node_id)
306 }
307}
308
309impl<M> CacheTree for UiTree<'_, '_, M> {
310 #[inline]
311 fn cache_get(
312 &self,
313 node_id: NodeId,
314 known_dimensions: Size<Option<f32>>,
315 available_space: Size<AvailableSpace>,
316 run_mode: RunMode,
317 ) -> Option<LayoutOutput> {
318 self.cache_query
319 .get(Entity::from_bits(node_id.into()))
320 .ok()
321 .and_then(|cache| cache.0.get(known_dimensions, available_space, run_mode))
322 }
323
324 #[inline]
325 fn cache_store(
326 &mut self,
327 node_id: NodeId,
328 known_dimensions: Size<Option<f32>>,
329 available_space: Size<AvailableSpace>,
330 run_mode: RunMode,
331 layout_output: LayoutOutput,
332 ) {
333 let e = Entity::from_bits(node_id.into());
334 if let Ok(mut cache) = self.cache_query.get_mut(e) {
335 cache.0.store(known_dimensions, available_space, run_mode, layout_output)
336 }
337 }
338
339 #[inline]
340 fn cache_clear(&mut self, node_id: NodeId) {
341 if let Ok(mut cache) = self.cache_query.get_mut(Entity::from_bits(node_id.into())) {
342 cache.0.clear()
343 }
344 }
345}
346
347pub(crate) fn compute_ui_tree(
348 world: &mut World,
349 compute_state: &mut SystemState<(
350 Query<(Ref<UiRootTrns>, &Children, Has<UiUnrounded>)>,
351 Query<(Entity, Has<UiRootTrns>, Read<Ui>, Option<Read<ContentSize>>)>,
352 Query<Read<Children>>,
353 Query<Write<IntermediateUi>>,
354 Query<Write<UiCache>>,
355 )>,
356 propagate_state: &mut SystemState<(
357 Query<(&UiRootTrns, &Children)>,
358 Query<(&mut Transform, &ComputedUi)>,
359 Query<&Children>,
360 )>,
361 mut outputs: Local<EntityHashMap<Layout>>,
362) {
363 world.resource_scope(|world, mut measurers: Mut<Measurements>| {
364 {
365 let cell = world.as_unsafe_world_cell();
366
367 compute_state.update_archetypes_unsafe_world_cell(cell);
368 let ((root_query, ui_query, children_query, intermediate_query, cache_query), measurements) =
369 unsafe { (compute_state.get_unchecked_manual(cell), measurers.get_measurers(cell)) };
370
371 let mut tree = UiTree {
372 measurements,
373 viewport_size: Vec2::ZERO,
374 ui_query,
375 children_query,
376 intermediate_query,
377 cache_query,
378 outputs: &mut outputs,
379 };
380
381 for (trns, roots, is_unrounded) in &root_query {
382 tree.viewport_size = trns.size;
383 for &root in roots {
384 if trns.is_changed() || tree.cache_query.get_mut(root).is_ok_and(|cache| cache.is_changed()) {
385 let node_id = NodeId::from(root.to_bits());
386 compute_root_layout(&mut tree, node_id, Size {
387 width: AvailableSpace::Definite(trns.size.x),
388 height: AvailableSpace::Definite(trns.size.y),
389 });
390
391 if !is_unrounded {
392 round_layout(&mut tree, node_id)
393 }
394 }
395 }
396 }
397 }
398
399 world.insert_batch(outputs.drain().map(|(e, layout)| {
400 (e, ComputedUi {
401 order: layout.order,
402 location: Vec2::new(layout.location.x, layout.location.y),
403 size: Vec2::new(layout.size.width, layout.size.height),
404 content_size: Vec2::new(layout.content_size.width, layout.content_size.height),
405 scrollbar_size: Vec2::new(layout.scrollbar_size.width, layout.scrollbar_size.height),
406 border: layout.border.into(),
407 padding: layout.padding.into(),
408 margin: layout.margin.into(),
409 })
410 }));
411
412 measurers.apply_measurers(world);
413
414 let (root_query, mut query, children_query) = propagate_state.get_mut(world);
415 for (trns, roots) in &root_query {
416 propagate(trns.transform, trns.size.y, roots, &mut query, &children_query);
417 }
418
419 fn propagate(
420 root_transform: Transform,
421 parent_height: f32,
422 entities: &[Entity],
423 query: &mut Query<(&mut Transform, &ComputedUi)>,
424 children_query: &Query<&Children>,
425 ) {
426 for &e in entities {
427 let Ok((mut trns, layout)) = query.get_mut(e) else { continue };
428 let pos = Vec3::new(layout.location.x, parent_height - layout.location.y - layout.size.y, 0.001);
429
430 trns.set_if_neq(root_transform * Transform::from_translation(pos));
431 if let Ok(children) = children_query.get(e) {
432 propagate(Transform::IDENTITY, layout.size.y, children, query, children_query)
433 }
434 }
435 }
436 })
437}