use std::collections::HashMap;
use bevy_ecs::prelude::*;
use taffy::Overflow;
use taffy::prelude::*;
use crate::point::Point;
use crate::renderer::EmptyMeasurer;
use crate::renderer::Measurer;
use crate::renderer::ScaleFactor;
use crate::renderer::schedule::Measure;
use crate::style::Rectangle;
use crate::widget::text::TextNode;
/// Marker for the top of a layout hierarchy.
///
/// Inserting it triggers `on_insert_root`, which attaches a name, an empty
/// taffy tree and a no-op measurer. `pre_measure` sizes root nodes from
/// [`WindowSize`] divided by [`ScaleFactor`] instead of [`Width`]/[`Height`].
#[derive(Clone, Component, Debug)]
#[require(FlexDirection, FlexGrow, Height, Width)]
pub struct RootNode;

/// Marker for a layout container placed somewhere below a [`RootNode`].
///
/// Inserting it triggers `on_insert_container`, which names the entity and
/// gives it a no-op measurer.
#[derive(Clone, Component, Debug, Default)]
#[require(FlexDirection, FlexGrow, Height, Width)]
pub struct ContainerNode;
/// Display mode of a node, copied into `Style::display` by `pre_measure`.
#[derive(Clone, Component, Debug, Default)]
pub struct DisplayNode(pub Display);

/// Flex grow factor, copied into `Style::flex_grow`.
#[derive(Clone, Component, Debug, Default)]
pub struct FlexGrow(pub f32);

/// Flex basis, converted into `Style::flex_basis`. Defaults to `auto`.
#[derive(Clone, Component, Debug)]
pub struct FlexBasis(pub taffy::LengthPercentageAuto);

/// Cross-axis alignment, copied into `Style::align_items` (`None` leaves it unset).
#[derive(Clone, Component, Debug)]
pub struct AlignItemsNode(pub Option<AlignItems>);

/// Main-axis alignment, copied into `Style::justify_content` (`None` leaves it unset).
#[derive(Clone, Component, Debug)]
pub struct JustifyContentNode(pub Option<JustifyContent>);

/// Per-axis overflow behaviour, copied into `Style::overflow`.
#[derive(Clone, Component, Debug)]
pub struct OverflowNode(pub taffy::Point<Overflow>);

/// Scroll offset of a node; not read by the layout systems in this file.
#[derive(Clone, Component, Copy, Debug)]
pub struct ScrollNode(pub Point);

/// Padding on each side, copied into `Style::padding`.
#[derive(Clone, Component, Copy, Debug)]
pub struct PaddingNode(pub taffy::Rect<taffy::LengthPercentage>);
// SAFETY: `PaddingNode` and `FlexBasis` only wrap taffy length values. As with
// `RootTree` and `TaffyStyle`, sharing them across threads is assumed sound as
// long as taffy `calc` values are never used.
#[expect(unsafe_code, reason = "taffy length values are safe as long as calc is not used")]
unsafe impl Send for PaddingNode {}
#[expect(unsafe_code, reason = "taffy length values are safe as long as calc is not used")]
unsafe impl Sync for PaddingNode {}
#[expect(unsafe_code, reason = "taffy length values are safe as long as calc is not used")]
unsafe impl Send for FlexBasis {}
#[expect(unsafe_code, reason = "taffy length values are safe as long as calc is not used")]
unsafe impl Sync for FlexBasis {}
impl Default for FlexBasis {
fn default() -> Self {
FlexBasis(auto())
}
}
/// Main axis of a flex container, converted into `taffy::FlexDirection`.
///
/// Only the non-reversed directions are exposed; defaults to [`FlexDirection::Row`].
/// Entities carrying it also get a [`Gap`].
#[derive(Clone, Component, Debug, Default)]
#[require(Gap)]
pub enum FlexDirection {
    /// Maps to `taffy::FlexDirection::Column`.
    Column,
    /// Maps to `taffy::FlexDirection::Row` (the default).
    #[default]
    Row,
}
impl From<&FlexDirection> for taffy::FlexDirection {
    /// Translates the component into the taffy direction of the same name.
    fn from(value: &FlexDirection) -> Self {
        match *value {
            FlexDirection::Row => Self::Row,
            FlexDirection::Column => Self::Column,
        }
    }
}
/// Explicit grid row tracks, copied into `Style::grid_template_rows`.
#[derive(Component, Clone, Debug, Default)]
pub struct GridTemplateRows(pub Vec<GridTemplateComponent<String>>);
// SAFETY: only holds taffy grid track definitions; as with `RootTree` and
// `TaffyStyle`, assumed sound to share across threads as long as taffy `calc`
// values are never used.
#[expect(unsafe_code, reason = "taffy grid tracks are safe as long as calc is not used")]
unsafe impl Send for GridTemplateRows {}
#[expect(unsafe_code, reason = "taffy grid tracks are safe as long as calc is not used")]
unsafe impl Sync for GridTemplateRows {}

/// Explicit grid column tracks, copied into `Style::grid_template_columns`.
#[derive(Component, Clone, Debug, Default)]
pub struct GridTemplateColumns(pub Vec<GridTemplateComponent<String>>);
// SAFETY: same justification as for `GridTemplateRows`.
#[expect(unsafe_code, reason = "taffy grid tracks are safe as long as calc is not used")]
unsafe impl Send for GridTemplateColumns {}
#[expect(unsafe_code, reason = "taffy grid tracks are safe as long as calc is not used")]
unsafe impl Sync for GridTemplateColumns {}
/// Positioning scheme, copied into `Style::position` by `pre_measure`.
#[derive(Clone, Component, Debug, Default)]
pub struct PositionNode(pub Position);
/// Top inset, copied into `Style::inset.top` (absent means `auto`).
#[derive(Component, Clone, Debug)]
pub struct TopNode(pub LengthPercentageAuto);
// SAFETY: only wraps a taffy length value; as with `RootTree` and `TaffyStyle`,
// assumed sound to share across threads as long as taffy `calc` is not used.
#[expect(unsafe_code, reason = "taffy length values are safe as long as calc is not used")]
unsafe impl Send for TopNode {}
#[expect(unsafe_code, reason = "taffy length values are safe as long as calc is not used")]
unsafe impl Sync for TopNode {}

/// Right inset, copied into `Style::inset.right` (absent means `auto`).
#[derive(Component, Clone, Debug)]
pub struct RightNode(pub LengthPercentageAuto);
// SAFETY: same justification as for `TopNode`.
#[expect(unsafe_code, reason = "taffy length values are safe as long as calc is not used")]
unsafe impl Send for RightNode {}
#[expect(unsafe_code, reason = "taffy length values are safe as long as calc is not used")]
unsafe impl Sync for RightNode {}

/// Bottom inset, copied into `Style::inset.bottom` (absent means `auto`).
#[derive(Component, Clone, Debug)]
pub struct BottomNode(pub LengthPercentageAuto);
// SAFETY: same justification as for `TopNode`.
#[expect(unsafe_code, reason = "taffy length values are safe as long as calc is not used")]
unsafe impl Send for BottomNode {}
#[expect(unsafe_code, reason = "taffy length values are safe as long as calc is not used")]
unsafe impl Sync for BottomNode {}

/// Left inset, copied into `Style::inset.left` (absent means `auto`).
#[derive(Component, Clone, Debug)]
pub struct LeftNode(pub LengthPercentageAuto);
// SAFETY: same justification as for `TopNode`.
#[expect(unsafe_code, reason = "taffy length values are safe as long as calc is not used")]
unsafe impl Send for LeftNode {}
#[expect(unsafe_code, reason = "taffy length values are safe as long as calc is not used")]
unsafe impl Sync for LeftNode {}
/// Gap between children, applied to both axes of `Style::gap`. Defaults to zero.
#[derive(Component, Debug, Copy, Clone)]
pub struct Gap(pub taffy::LengthPercentage);
impl Default for Gap {
    fn default() -> Self {
        Gap(length(0.0))
    }
}
// SAFETY: only wraps a taffy length value; as with `RootTree` and `TaffyStyle`,
// assumed sound to share across threads as long as taffy `calc` is not used.
#[expect(unsafe_code, reason = "taffy length values are safe as long as calc is not used")]
unsafe impl Send for Gap {}
#[expect(unsafe_code, reason = "taffy length values are safe as long as calc is not used")]
unsafe impl Sync for Gap {}

/// Preferred height, copied into `Style::size.height` (roots use the window
/// height instead). Defaults to `auto`.
#[derive(Component, Debug, Copy, Clone)]
pub struct Height(pub taffy::LengthPercentageAuto);
impl Default for Height {
    fn default() -> Self {
        Height(auto())
    }
}
// SAFETY: same justification as for `Gap`.
#[expect(unsafe_code, reason = "taffy length values are safe as long as calc is not used")]
unsafe impl Send for Height {}
#[expect(unsafe_code, reason = "taffy length values are safe as long as calc is not used")]
unsafe impl Sync for Height {}

/// Preferred width, copied into `Style::size.width` (roots use the window
/// width instead). Defaults to `auto`.
#[derive(Component, Debug, Copy, Clone)]
pub struct Width(pub taffy::LengthPercentageAuto);
impl Default for Width {
    fn default() -> Self {
        Width(auto())
    }
}
// SAFETY: same justification as for `Gap`.
#[expect(unsafe_code, reason = "taffy length values are safe as long as calc is not used")]
unsafe impl Send for Width {}
#[expect(unsafe_code, reason = "taffy length values are safe as long as calc is not used")]
unsafe impl Sync for Width {}
/// Taffy tree attached to each [`RootNode`] entity by `on_insert_root`.
///
/// Child nodes are created with their ECS [`Entity`] as node context, which
/// `measure` uses to look up the entity's [`Measurer`].
#[derive(Component)]
struct RootTree(TaffyTree<Entity>);

// SAFETY: relies on taffy `calc` values never being used, as stated below.
#[expect(unsafe_code, reason = "TaffyTree is safe as long as calc is not used")]
unsafe impl Send for RootTree {}
#[expect(unsafe_code, reason = "TaffyTree is safe as long as calc is not used")]
unsafe impl Sync for RootTree {}

/// Id of the taffy node representing this entity inside its root's [`RootTree`].
#[derive(Component)]
struct TaffyNodeId(NodeId);

/// Taffy style built for a layout entity by `pre_measure`, together with the
/// invalidation flags that `measure` consumes and `post_measure` clears.
#[derive(Component)]
pub struct TaffyStyle {
    /// When set, `measure` pushes `style` into the taffy tree.
    pub is_style_dirty: bool,
    /// When set, `measure` marks the taffy node dirty so it is measured again.
    pub is_content_dirty: bool,
    style: taffy::Style,
}

// SAFETY: relies on taffy `calc` values never being used, as stated below.
#[expect(unsafe_code, reason = "TaffyStyle is safe as long as calc is not used")]
unsafe impl Send for TaffyStyle {}
#[expect(unsafe_code, reason = "TaffyStyle is safe as long as calc is not used")]
unsafe impl Sync for TaffyStyle {}

/// Rectangle written by `measure`: the taffy box offset by the accumulated
/// locations of its ancestors.
#[derive(Component, Debug)]
pub struct MeasuredLayout(pub Rectangle);

/// Rectangle an entity was rendered at; not written by the systems in this file.
#[derive(Component, Debug)]
pub struct RenderedLayout(pub Rectangle);
/// Observer: equips a newly inserted [`RootNode`] with a "Root" name, an empty
/// taffy tree and a no-op [`Measurer`].
///
/// Re-inserting `RootNode` on an entity that already owns a [`RootTree`] does
/// nothing, so the existing tree is kept.
fn on_insert_root(
    root: On<Insert, RootNode>,
    mut commands: Commands,
    root_tree_query: Query<&RootTree>,
) {
    let entity = root.entity;
    tracing::trace!("Root node inserted: {:?}", entity);
    let already_has_tree = root_tree_query.get(entity).is_ok();
    if !already_has_tree {
        commands.entity(entity).insert((
            Name::new("Root"),
            RootTree(TaffyTree::new()),
            Measurer(Box::new(EmptyMeasurer)),
        ));
    }
}
/// Observer: names a newly inserted [`ContainerNode`] "Container" and gives it a
/// no-op [`Measurer`].
fn on_insert_container(container: On<Insert, ContainerNode>, mut commands: Commands) {
    let entity = container.entity;
    tracing::trace!("Container node inserted: {:?}", entity);
    let components = (Name::new("Container"), Measurer(Box::new(EmptyMeasurer)));
    commands.entity(entity).insert(components);
}
/// Style inputs read by `pre_measure` for every root, container and text node:
/// the cached [`TaffyStyle`], the root marker and the optional flex/grid
/// components.
type ContainerDisplayQuery<'w, 's> = Query<
    'w,
    's,
    (
        Entity,
        Option<&'static TaffyStyle>,
        Option<&'static RootNode>,
        Option<&'static DisplayNode>,
        Option<&'static FlexDirection>,
        Option<&'static FlexGrow>,
        Option<&'static GridTemplateRows>,
        Option<&'static GridTemplateColumns>,
        Option<&'static FlexBasis>,
        Option<&'static AlignItemsNode>,
        Option<&'static JustifyContentNode>,
        Option<&'static OverflowNode>,
        Option<&'static Gap>,
    ),
    Or<(With<RootNode>, With<ContainerNode>, With<TextNode>)>,
>;

/// Sizing, positioning and padding inputs read by `pre_measure`. Uses the same
/// filter as [`ContainerDisplayQuery`], so both match the same entities.
type ContainerDimensionsQuery<'w, 's> = Query<
    'w,
    's,
    (
        Option<&'static Height>,
        Option<&'static Width>,
        Option<&'static PositionNode>,
        Option<&'static TopNode>,
        Option<&'static RightNode>,
        Option<&'static BottomNode>,
        Option<&'static LeftNode>,
        Option<&'static PaddingNode>,
    ),
    Or<(With<RootNode>, With<ContainerNode>, With<TextNode>)>,
>;
/// Rebuilds the cached [`TaffyStyle`] of every root, container and text node
/// from its layout components.
///
/// Each taffy style field takes the value of the matching component when it is
/// present, and `taffy::Style::default()`'s value otherwise. `is_style_dirty` is
/// raised when any field changed or whenever [`ScaleFactor`] changed. The
/// previous `is_content_dirty` flag is carried over unchanged. Roots ignore
/// [`Width`]/[`Height`] and are sized to [`WindowSize`] divided by the scale
/// factor.
///
/// # Panics
///
/// Panics if an entity matched by `containers_display` is not matched by
/// `containers_dimensions`; both share the same filter.
fn pre_measure(
    mut commands: Commands,
    window_size: Res<WindowSize>,
    scale_factor: Res<ScaleFactor>,
    containers_display: ContainerDisplayQuery,
    containers_dimensions: ContainerDimensionsQuery,
) {
    // Stores `desired` in `slot` when the two differ and reports whether it did.
    fn reconcile<T: PartialEq>(slot: &mut T, desired: T) -> bool {
        let changed = *slot != desired;
        if changed {
            *slot = desired;
        }
        changed
    }

    let defaults = taffy::Style::default();

    for (
        entity,
        cached,
        root_marker,
        display,
        flex_direction,
        flex_grow,
        rows,
        columns,
        flex_basis,
        align_items,
        justify_content,
        overflow,
        gap,
    ) in &containers_display
    {
        let (height, width, position, top, right, bottom, left, padding) =
            containers_dimensions.get(entity).unwrap();

        let size = match root_marker {
            Some(_) => Size {
                height: length(window_size.0.height() as f32 / scale_factor.0),
                width: length(window_size.0.width() as f32 / scale_factor.0),
            },
            None => Size {
                width: width.map(|w| w.0.into()).unwrap_or_else(auto),
                height: height.map(|h| h.0.into()).unwrap_or_else(auto),
            },
        };
        let inset = Rect {
            top: top.map(|t| t.0.into()).unwrap_or_else(auto),
            right: right.map(|r| r.0.into()).unwrap_or_else(auto),
            bottom: bottom.map(|b| b.0.into()).unwrap_or_else(auto),
            left: left.map(|l| l.0.into()).unwrap_or_else(auto),
        };

        let mut style = cached.map(|c| c.style.clone()).unwrap_or_default();
        let mut dirty = false;

        dirty |= reconcile(&mut style.display, display.map_or(defaults.display, |d| d.0));
        dirty |= reconcile(&mut style.position, position.map_or(defaults.position, |p| p.0));
        dirty |= reconcile(
            &mut style.flex_direction,
            flex_direction.map_or(defaults.flex_direction, |d| d.into()),
        );
        dirty |= reconcile(&mut style.flex_grow, flex_grow.map_or(defaults.flex_grow, |g| g.0));
        dirty |= reconcile(
            &mut style.flex_basis,
            flex_basis.map_or(defaults.flex_basis, |b| b.0.into()),
        );
        dirty |= reconcile(
            &mut style.align_items,
            align_items.map_or(defaults.align_items, |a| a.0),
        );
        dirty |= reconcile(
            &mut style.justify_content,
            justify_content.map_or(defaults.justify_content, |j| j.0),
        );
        dirty |= reconcile(&mut style.overflow, overflow.map_or(defaults.overflow, |o| o.0));
        dirty |= reconcile(
            &mut style.gap,
            gap.map_or(defaults.gap, |g| Size {
                width: g.0,
                height: g.0,
            }),
        );
        dirty |= reconcile(
            &mut style.grid_template_rows,
            rows.map_or_else(|| defaults.grid_template_rows.clone(), |r| r.0.clone()),
        );
        dirty |= reconcile(
            &mut style.grid_template_columns,
            columns.map_or_else(|| defaults.grid_template_columns.clone(), |c| c.0.clone()),
        );
        dirty |= reconcile(&mut style.inset, inset);
        dirty |= reconcile(&mut style.size, size);
        dirty |= reconcile(&mut style.padding, padding.map_or(defaults.padding, |p| p.0));

        // A scale change invalidates every node even if no field moved.
        dirty |= scale_factor.is_changed();

        commands.entity(entity).insert(TaffyStyle {
            is_style_dirty: dirty,
            is_content_dirty: cached.is_some_and(|c| c.is_content_dirty),
            style,
        });
    }
}
/// Non-root layout entities that `measure` walks below each root: containers and
/// text nodes, with their name, taffy node id (if already created), children
/// and cached [`TaffyStyle`].
type ContainerStyleQuery<'w, 's> = Query<
    'w,
    's,
    (
        &'static Name,
        Entity,
        Option<&'static TaffyNodeId>,
        Option<&'static Children>,
        &'static TaffyStyle,
    ),
    Or<(With<ContainerNode>, With<TextNode>)>,
>;
fn measure(
mut commands: Commands,
mut root: Query<
(
Entity,
&Name,
&mut RootTree,
Option<&TaffyNodeId>,
Option<&Children>,
&TaffyStyle,
),
With<RootNode>,
>,
containers_display: ContainerStyleQuery,
mut measurers: Query<&mut Measurer>,
) {
tracing::debug!("Measuring layout");
for (root_entity, root_name, mut root_tree, root_node_id, root_children, root_style) in
&mut root
{
tracing::trace!("Measuring root: {root_name}");
let mut entities = HashMap::new();
let root_node_id = match root_node_id {
Some(root_node_id) => {
if root_style.is_style_dirty {
root_tree
.0
.set_style(root_node_id.0, root_style.style.clone())
.unwrap();
}
if root_style.is_content_dirty {
root_tree.0.mark_dirty(root_node_id.0).unwrap();
}
root_node_id.0
}
None => {
let root_node_id = root_tree.0.new_leaf(root_style.style.clone()).unwrap();
commands
.entity(root_entity)
.insert(TaffyNodeId(root_node_id));
root_node_id
}
};
entities.insert(root_node_id, root_entity);
fn measure_children(
commands: &mut Commands,
root_tree: &mut TaffyTree<Entity>,
containers: &ContainerStyleQuery,
parent_node_id: taffy::NodeId,
children: &Children,
entities: &mut HashMap<taffy::NodeId, Entity>,
) {
for child in children.iter() {
let Ok((child_name, entity, child_node_id, children, child_style)) =
containers.get(child)
else {
continue;
};
tracing::trace!("Adding child: {:?} ({})", entity, child_name);
let child_node_id = match child_node_id {
Some(child_node_id) => {
if child_style.is_style_dirty {
root_tree
.set_style(child_node_id.0, child_style.style.clone())
.unwrap();
}
if child_style.is_content_dirty {
root_tree.mark_dirty(child_node_id.0).unwrap();
}
child_node_id.0
}
None => {
let child_node_id = root_tree
.new_leaf_with_context(child_style.style.clone(), entity)
.unwrap();
root_tree.add_child(parent_node_id, child_node_id).unwrap();
commands.entity(entity).insert(TaffyNodeId(child_node_id));
child_node_id
}
};
entities.insert(child_node_id, entity);
if let Some(children) = children {
measure_children(
commands,
root_tree,
containers,
child_node_id,
children,
entities,
);
}
}
}
if let Some(root_children) = root_children {
measure_children(
&mut commands,
&mut root_tree.0,
&containers_display,
root_node_id,
root_children,
&mut entities,
);
}
root_tree
.0
.compute_layout_with_measure(
root_node_id,
Size::MAX_CONTENT,
|known_dimensions, available_space, _node_id, node_context, _style| {
if let Size {
width: Some(width),
height: Some(height),
} = known_dimensions
{
return Size { width, height };
}
if let Some(entity) = node_context {
let mut measurer = measurers.get_mut(*entity).unwrap();
let inner_size = measurer.0.measure(available_space);
return Size {
width: inner_size.width().into(),
height: inner_size.height().into(),
};
}
Size::zero()
},
)
.unwrap();
fn sync(
commands: &mut Commands,
root_tree: &TaffyTree<Entity>,
parent_node_id: taffy::NodeId,
parent_rectangle: Rectangle,
entities: &HashMap<taffy::NodeId, Entity>,
) {
let parent_entity = entities.get(&parent_node_id).unwrap();
commands
.entity(*parent_entity)
.insert(MeasuredLayout(parent_rectangle));
for child_id in root_tree.children(parent_node_id).unwrap() {
let entity = entities.get(&child_id).unwrap();
let layout = root_tree.layout(child_id).unwrap();
tracing::trace!(
?entity,
"Computed layout for entity: {:?} / {:?}",
layout.size,
layout.location
);
let child_rectangle = Rectangle::from_xywh(
parent_rectangle.x() + layout.location.x,
parent_rectangle.y() + layout.location.y,
layout.size.width,
layout.size.height,
);
sync(commands, root_tree, child_id, child_rectangle, entities);
}
}
let root_layout = root_tree.0.layout(root_node_id).unwrap();
let root_rectangle = Rectangle::from_xywh(
root_layout.location.x,
root_layout.location.y,
root_layout.size.width,
root_layout.size.height,
);
sync(
&mut commands,
&root_tree.0,
root_node_id,
root_rectangle,
&entities,
);
}
}
/// Clears the dirty flags on every [`TaffyStyle`] after `measure` has consumed them.
fn post_measure(mut containers: Query<&mut TaffyStyle>) {
    containers.iter_mut().for_each(|mut taffy_style| {
        taffy_style.is_style_dirty = false;
        taffy_style.is_content_dirty = false;
    });
}
fn on_remove_taffy_node_id(
removed: On<Remove, TaffyNodeId>,
nodes_query: Query<&TaffyNodeId>,
mut root: Query<&mut RootTree>,
) {
for mut root_tree in &mut root {
let taffy_node_id = nodes_query.get(removed.entity).unwrap();
root_tree.0.remove(taffy_node_id.0).unwrap();
}
}
/// Size of the window that root nodes are laid out in. `pre_measure` divides it
/// by [`ScaleFactor`]; `layout_setup` initialises it to 600×400.
#[derive(Debug, Resource)]
pub struct WindowSize(pub Rectangle);
/// Wires the layout module into `world`.
///
/// Registers the root/container insert observers and the taffy-node removal
/// observer, inserts a default 600×400 [`WindowSize`], and adds the chained
/// `pre_measure → measure → post_measure` systems to the [`Measure`] schedule.
pub fn layout_setup(world: &mut World) {
    world.add_observer(on_insert_root);
    world.add_observer(on_insert_container);
    world.add_observer(on_remove_taffy_node_id);
    world.insert_resource(WindowSize(Rectangle::from_size(600.0, 400.0)));

    let layout_systems = (pre_measure, measure, post_measure).chain();
    let mut schedules = world.get_resource_or_init::<Schedules>();
    schedules.add_systems(Measure, layout_systems);
}