use bevy::prelude::*;
use bevy_gearbox::active::Active;
use bevy_gearbox::{InitialState, StateMachine};
use bevy_egui::egui;
use std::collections::HashSet;
use crate::editor_state::{EditorState, StateMachinePersistentData, StateMachineTransientData, NodeDragged, NodeContextMenuRequested, TransitionContextMenuRequested, RenderItem, get_entity_name, should_get_selection_boost, TransitionCreationRequested, CreateTransition, SaveStateMachine, draw_arrow, draw_interactive_pill_label, closest_point_on_rect_edge, get_node_display_color, get_transition_color};
use crate::components::{NodeType, LeafNode, ParentNode};
/// Keeps the editor's node layout for the selected state machine in sync with the ECS hierarchy.
///
/// For the selected root and every descendant reachable through `StateChildren`:
/// - entities with `InitialState` or `Parallel` become `NodeType::Parent` entries;
/// - remaining entities matched by `leaf_query` become `NodeType::Leaf` entries.
///
/// When a node switches kind, its previous position is carried over. Brand-new parents
/// start at (200, 100) and brand-new leaves at (100, 100). Node entries for entities no
/// longer in the hierarchy are removed.
pub fn update_node_types(
    editor_state: Res<EditorState>,
    mut state_machines: Query<&mut StateMachinePersistentData, With<StateMachine>>,
    parent_query: Query<Entity, With<InitialState>>,
    leaf_query: Query<Entity, Without<InitialState>>,
    children_query: Query<&bevy_gearbox::StateChildren>,
    parallel_query: Query<Entity, With<bevy_gearbox::Parallel>>,
) {
    let Some(root) = editor_state.selected_machine else { return };
    let Ok(mut machine_data) = state_machines.get_mut(root) else { return };

    // Root first, followed by all descendants in depth-first order.
    let hierarchy: Vec<Entity> = std::iter::once(root)
        .chain(children_query.iter_descendants_depth_first(root))
        .collect();

    for &entity in &hierarchy {
        let is_composite = parent_query.contains(entity) || parallel_query.contains(entity);
        if is_composite {
            // Only build a new node when the existing entry is missing or of the wrong kind.
            let replacement = match machine_data.nodes.get(&entity) {
                Some(NodeType::Parent(_)) => None,
                Some(NodeType::Leaf(leaf)) => Some(ParentNode::new(leaf.entity_node.position)),
                None => Some(ParentNode::new(egui::Pos2::new(200.0, 100.0))),
            };
            if let Some(parent) = replacement {
                machine_data.nodes.insert(entity, NodeType::Parent(parent));
            }
        } else if leaf_query.contains(entity) {
            let replacement = match machine_data.nodes.get(&entity) {
                Some(NodeType::Leaf(_)) => None,
                Some(NodeType::Parent(parent)) => Some(LeafNode::new(parent.entity_node.position)),
                None => Some(LeafNode::new(egui::Pos2::new(100.0, 100.0))),
            };
            if let Some(leaf) = replacement {
                machine_data.nodes.insert(entity, NodeType::Leaf(leaf));
            }
        }
    }

    // Drop layout entries for states that are no longer part of this machine.
    let live: HashSet<Entity> = hierarchy.into_iter().collect();
    machine_data.nodes.retain(|entity, _| live.contains(entity));
}
pub fn show_machine_editor(
ctx: &egui::Context,
editor_state: &mut EditorState,
persistent_data: &mut StateMachinePersistentData,
transient_data: &mut StateMachineTransientData,
all_entities: &Query<(Entity, Option<&Name>, Option<&InitialState>)>,
child_of_query: &Query<&bevy_gearbox::StateChildOf>,
children_query: &Query<&bevy_gearbox::StateChildren>,
active_query: &Query<&Active>,
parallel_query: &Query<&bevy_gearbox::Parallel>,
commands: &mut Commands,
) {
egui::CentralPanel::default().show(ctx, |ui| {
ui.horizontal(|ui| {
if ui.button("← Back to Machine List").clicked() {
editor_state.selected_machine = None;
transient_data.selected_node = None;
}
if let Some(selected_root) = editor_state.selected_machine {
let machine_name = get_entity_name(selected_root, all_entities);
ui.separator();
ui.label(format!("Editing: {}", machine_name));
ui.separator();
if ui.button("💾 Save").clicked() {
commands.trigger(SaveStateMachine { entity: selected_root });
}
}
});
ui.separator();
if let Some(selected_root) = editor_state.selected_machine {
let mut render_queue = Vec::new();
let mut hierarchy_entities: Vec<Entity> = children_query
.iter_descendants_depth_first(selected_root)
.collect();
hierarchy_entities.insert(0, selected_root);
for (hierarchy_index, entity) in hierarchy_entities.iter().enumerate() {
if let Some(_node) = persistent_data.nodes.get(entity) {
let base_z_order = hierarchy_index as i32 * 10;
let selection_boost = if should_get_selection_boost(*entity, transient_data.selected_node, child_of_query) {
5
} else {
0
};
render_queue.push(RenderItem {
entity: *entity,
z_order: base_z_order + selection_boost,
});
}
}
render_queue.sort_by_key(|item| item.z_order);
for render_item in render_queue {
let entity = render_item.entity;
let entity_name = get_entity_name(entity, all_entities);
if let Some(node) = persistent_data.nodes.get_mut(&entity) {
let is_selected = transient_data.selected_node == Some(entity);
let is_root = selected_root == entity;
let is_editing = transient_data.text_editing.is_editing(entity);
let should_focus = transient_data.text_editing.should_focus;
let first_focus = transient_data.text_editing.first_focus;
let node_color = Some(get_node_display_color(entity, active_query, &transient_data.node_pulses));
let response = match node {
NodeType::Leaf(leaf_node) => {
let dotted = is_direct_child_of_parallel(entity, child_of_query, parallel_query);
leaf_node.show_with_border_style(
ui,
&entity_name,
Some(&format!("{:?}", entity)),
is_selected,
is_editing,
&mut transient_data.text_editing.current_text,
should_focus,
first_focus,
node_color,
dotted,
)
}
NodeType::Parent(parent_node) => {
let dotted = is_direct_child_of_parallel(entity, child_of_query, parallel_query);
parent_node.show_with_border_style(
ui,
&entity_name,
Some(&format!("{:?}", entity)),
is_selected,
is_root,
is_editing,
&mut transient_data.text_editing.current_text,
should_focus,
first_focus,
node_color,
dotted,
)
}
};
if should_focus {
transient_data.text_editing.should_focus = false;
}
if first_focus {
transient_data.text_editing.first_focus = false;
}
if response.clicked {
if transient_data.transition_creation.awaiting_target_selection {
let pointer_pos = ui.input(|i| i.pointer.hover_pos().unwrap_or_default());
transient_data.transition_creation.set_target(entity, pointer_pos);
} else {
transient_data.selected_node = Some(entity);
}
}
if response.add_transition_clicked {
commands.trigger(TransitionCreationRequested {
source_entity: entity,
});
}
if response.right_clicked {
let pointer_pos = ui.input(|i| i.pointer.hover_pos().unwrap_or_default());
commands.trigger(NodeContextMenuRequested {
entity,
position: pointer_pos,
});
}
if response.dragged {
commands.trigger(NodeDragged {
entity,
drag_delta: response.drag_delta,
});
}
}
}
update_transition_rectangles(persistent_data, child_of_query);
render_transition_connections(ui, persistent_data, transient_data, child_of_query, commands);
render_initial_state_indicators(ui, persistent_data, &all_entities, selected_root);
if transient_data.transition_creation.awaiting_target_selection {
if ui.input(|i| i.pointer.primary_clicked()) {
let pointer_pos = ui.input(|i| i.pointer.hover_pos().unwrap_or_default());
let clicked_on_node = persistent_data.nodes.values().any(|node| {
node.current_rect().contains(pointer_pos)
});
if !clicked_on_node {
transient_data.transition_creation.cancel();
}
}
}
} else {
ui.label("No state machine selected");
}
handle_text_editing_completion(ui, transient_data, commands);
render_transition_creation_ui(ui, persistent_data, transient_data, commands);
});
}
/// Draws the transient UI used while creating a transition.
///
/// - While awaiting a target, draws a dashed arrow from the source node's edge to the mouse.
///   A right-click or Escape cancels creation.
/// - While the event dropdown is shown, lists `available_event_types` in a foreground popup at
///   `dropdown_position`. Choosing one triggers `CreateTransition`. "Cancel", or a click outside
///   the popup, cancels creation.
fn render_transition_creation_ui(
    ui: &mut egui::Ui,
    persistent_data: &mut StateMachinePersistentData,
    transient_data: &mut StateMachineTransientData,
    commands: &mut Commands,
) {
    if transient_data.transition_creation.awaiting_target_selection {
        if let Some(source) = transient_data.transition_creation.source_entity {
            if let Some(source_node) = persistent_data.nodes.get(&source) {
                let mouse_pos = ui.input(|i| i.pointer.hover_pos().unwrap_or_default());
                let source_edge = closest_point_on_rect_edge(source_node.current_rect(), mouse_pos);
                draw_dashed_arrow(ui.painter(), source_edge, mouse_pos, egui::Color32::WHITE);
            }
        }
        // Cancellation does not depend on the source entity, so the editor can never get
        // stuck in target-selection mode.
        if ui.input(|i| i.pointer.secondary_clicked() || i.key_pressed(egui::Key::Escape)) {
            transient_data.transition_creation.cancel();
        }
    }

    if !transient_data.transition_creation.show_event_dropdown {
        return;
    }
    let (Some(source), Some(target), Some(position)) = (
        transient_data.transition_creation.source_entity,
        transient_data.transition_creation.target_entity,
        transient_data.transition_creation.dropdown_position,
    ) else {
        return;
    };

    let dropdown = egui::Area::new(egui::Id::new("transition_event_dropdown"))
        .fixed_pos(position)
        .order(egui::Order::Foreground)
        .show(ui.ctx(), |ui| {
            egui::Frame::popup(ui.style()).show(ui, |ui| {
                ui.set_min_width(200.0);
                ui.heading("Select Event Type");
                ui.separator();
                let creation = &mut transient_data.transition_creation;
                if creation.available_event_types.is_empty() {
                    ui.label("No EventEdge event types found.");
                    ui.label("Make sure event types are registered with the type registry.");
                } else {
                    // Borrowing the list directly is fine: nothing in this loop mutates `creation`.
                    for event_type in &creation.available_event_types {
                        if ui.button(event_type).clicked() {
                            commands.trigger(CreateTransition {
                                source_entity: source,
                                target_entity: target,
                                event_type: event_type.clone(),
                            });
                        }
                    }
                }
                ui.separator();
                if ui.button("Cancel").clicked() {
                    creation.cancel();
                }
            });
        });

    // Dismiss on clicks outside the popup's real on-screen rectangle. A fixed-size guess
    // misclassifies clicks whenever the list is taller or wider than assumed.
    let dropdown_rect = dropdown.response.rect;
    if ui.input(|i| i.pointer.any_click()) {
        let pointer_pos = ui.input(|i| i.pointer.hover_pos().unwrap_or_default());
        if !dropdown_rect.contains(pointer_pos) {
            transient_data.transition_creation.cancel();
        }
    }
}
/// Draws every visual transition and processes interaction with its event pill.
///
/// Works in three passes:
/// 1. Connector geometry in white. Ancestor→descendant transitions get a fish hook out of the
///    source rect plus an arrow to the target. All others get two arrows (source→pill and
///    pill→target).
/// 2. Interactive event pills, tinted by `get_transition_color`.
/// 3. Pill responses: right-click triggers `TransitionContextMenuRequested`. Dragging moves the
///    pill and, when the drag ends, stores its new offset.
fn render_transition_connections(
    ui: &mut egui::Ui,
    persistent_data: &mut StateMachinePersistentData,
    transient_data: &StateMachineTransientData,
    child_of_query: &Query<&bevy_gearbox::StateChildOf>,
    commands: &mut Commands,
) {
    // Snapshot per-transition drawing data first, so the mutable pass at the end can index
    // back into `visual_transitions` freely.
    let snapshots: Vec<_> = persistent_data
        .visual_transitions
        .iter()
        .enumerate()
        .map(|(idx, conn)| {
            let tint = get_transition_color(
                conn.source_entity,
                conn.target_entity,
                &transient_data.transition_pulses,
            );
            (
                idx,
                conn.calculate_two_segment_points(),
                conn.event_node_position,
                conn.event_type.clone(),
                conn.is_dragging_event_node,
                tint,
            )
        })
        .collect();

    // Pass 1: connectors.
    let painter = ui.painter();
    for (idx, (src_from, src_to, tgt_from, tgt_to), event_pos, ..) in &snapshots {
        let conn = &persistent_data.visual_transitions[*idx];
        if is_ancestor_of(conn.source_entity, conn.target_entity, child_of_query) {
            draw_fish_hook_to_point(painter, conn.source_rect, *event_pos, egui::Color32::WHITE);
            draw_arrow(painter, *event_pos, *tgt_to, egui::Color32::WHITE);
        } else {
            draw_arrow(painter, *src_from, *src_to, egui::Color32::WHITE);
            draw_arrow(painter, *tgt_from, *tgt_to, egui::Color32::WHITE);
        }
    }

    // Pass 2: pills (drawn after all connectors so they sit on top).
    let mut responses = Vec::with_capacity(snapshots.len());
    for (idx, _segments, event_pos, label, dragging, tint) in snapshots {
        let font = egui::FontId::new(12.0, egui::FontFamily::Proportional);
        let response = draw_interactive_pill_label(ui, event_pos, &label, font, dragging, tint);
        responses.push((idx, response));
    }

    // Pass 3: react to pill interaction.
    for (idx, response) in responses {
        let conn = &mut persistent_data.visual_transitions[idx];
        if response.secondary_clicked() {
            let pointer_pos = ui.input(|i| i.pointer.hover_pos().unwrap_or_default());
            commands.trigger(TransitionContextMenuRequested {
                source_entity: conn.source_entity,
                target_entity: conn.target_entity,
                event_type: conn.event_type.clone(),
                position: pointer_pos,
            });
        }
        if response.drag_started() {
            conn.is_dragging_event_node = true;
        }
        if response.dragged() && conn.is_dragging_event_node {
            conn.event_node_position += response.drag_delta();
        }
        if response.drag_stopped() {
            conn.is_dragging_event_node = false;
            conn.update_event_node_offset();
        }
    }
}
/// Returns `true` when `entity`'s immediate `StateChildOf` parent has a `Parallel` component.
/// Returns `false` when `entity` has no such parent.
fn is_direct_child_of_parallel(
    entity: Entity,
    child_of_query: &Query<&bevy_gearbox::StateChildOf>,
    parallel_query: &Query<&bevy_gearbox::Parallel>,
) -> bool {
    child_of_query
        .get(entity)
        .is_ok_and(|child_of| parallel_query.get(child_of.0).is_ok())
}
/// Copies the current node rectangles onto each visual transition's cached
/// `source_rect` / `target_rect`.
///
/// For every transition whose pill is not being dragged, it also recomputes the event pill
/// position and clamps it with `constrain_event_node_position`. Transitions whose endpoint has
/// no node keep their previous cached rect.
fn update_transition_rectangles(
    persistent_data: &mut StateMachinePersistentData,
    child_of_query: &Query<&bevy_gearbox::StateChildOf>,
) {
    let rects_by_entity: std::collections::HashMap<Entity, egui::Rect> = persistent_data
        .nodes
        .iter()
        .map(|(&entity, node)| (entity, node.current_rect()))
        .collect();

    for transition in persistent_data.visual_transitions.iter_mut() {
        if let Some(&rect) = rects_by_entity.get(&transition.source_entity) {
            transition.source_rect = rect;
        }
        if let Some(&rect) = rects_by_entity.get(&transition.target_entity) {
            transition.target_rect = rect;
        }
        // A pill under the user's cursor keeps whatever position the drag gave it.
        if transition.is_dragging_event_node {
            continue;
        }
        transition.update_event_node_position();
        constrain_event_node_position(transition, &rects_by_entity, child_of_query);
    }
}
fn constrain_event_node_position(
transition: &mut crate::TransitionConnection,
node_rects: &std::collections::HashMap<Entity, egui::Rect>,
child_of_query: &Query<&bevy_gearbox::StateChildOf>,
) {
let source_depth = hierarchy_depth_from_pairs(transition.source_entity, child_of_query);
let target_depth = hierarchy_depth_from_pairs(transition.target_entity, child_of_query);
let higher = if source_depth <= target_depth { transition.source_entity } else { transition.target_entity };
let other = if higher == transition.source_entity { transition.target_entity } else { transition.source_entity };
let is_direct_child = match child_of_query.get(other) { Ok(rel) => rel.0 == higher, Err(_) => false };
let parent_for_pill = if is_direct_child { higher } else if let Ok(rel) = child_of_query.get(higher) { rel.0 } else { higher };
if let Some(parent_rect) = node_rects.get(&parent_for_pill) {
let content_rect = egui::Rect::from_min_max(
egui::Pos2::new(parent_rect.min.x, parent_rect.min.y + 30.0),
parent_rect.max,
);
let margin = egui::Vec2::new(10.0, 10.0);
let pill_half = egui::Vec2::new(45.0, 12.0);
let min_allowed = content_rect.min + margin + pill_half;
let max_allowed = content_rect.max - margin - pill_half;
if min_allowed.x <= max_allowed.x && min_allowed.y <= max_allowed.y {
transition.event_node_position = egui::Pos2::new(
transition.event_node_position.x.clamp(min_allowed.x, max_allowed.x),
transition.event_node_position.y.clamp(min_allowed.y, max_allowed.y),
);
}
}
}
/// Ends an in-progress node rename when Enter or Escape is pressed or any pointer click occurs.
///
/// Escape discards the edit. Enter or a click commits it: the trimmed text is inserted as the
/// entity's `Name`. A name that trims to nothing is not applied; that case is logged instead.
fn handle_text_editing_completion(
    ui: &mut egui::Ui,
    transient_data: &mut StateMachineTransientData,
    commands: &mut Commands,
) {
    if transient_data.text_editing.editing_entity.is_none() {
        return;
    }
    let (enter, escape, clicked) = ui.input(|i| {
        (
            i.key_pressed(egui::Key::Enter),
            i.key_pressed(egui::Key::Escape),
            i.pointer.any_click(),
        )
    });
    if escape {
        transient_data.text_editing.cancel_editing();
        return;
    }
    if !(enter || clicked) {
        return;
    }
    let Some((entity, new_name)) = transient_data.text_editing.stop_editing() else {
        return;
    };
    let trimmed_name = new_name.trim();
    if trimmed_name.is_empty() {
        info!("⚠️ Ignoring empty name for entity {:?}", entity);
    } else {
        commands.entity(entity).insert(Name::new(trimmed_name.to_string()));
    }
}
/// Draws an initial-state marker next to each state that is some parent's `InitialState`.
///
/// A marker is drawn only when both the parent and its initial child have entries in
/// `persistent_data.nodes`. Nothing is drawn if `selected_root` is not matched by
/// `all_entities`.
fn render_initial_state_indicators(
    ui: &mut egui::Ui,
    persistent_data: &StateMachinePersistentData,
    all_entities: &Query<(Entity, Option<&Name>, Option<&InitialState>)>,
    selected_root: Entity,
) {
    // Checked once up front. A per-entity rescan of the query for the root made this pass
    // O(n²) in the number of entities while always yielding the same answer.
    if !all_entities.contains(selected_root) {
        return;
    }
    let painter = ui.painter();
    for (parent_entity, _name, initial_state) in all_entities.iter() {
        let Some(initial_state) = initial_state else {
            continue;
        };
        if !persistent_data.nodes.contains_key(&parent_entity) {
            continue;
        }
        let Some(target_node) = persistent_data.nodes.get(&initial_state.0) else {
            continue;
        };
        render_initial_state_indicator(painter, target_node.current_rect());
    }
}
/// Paints one initial-state marker for the node occupying `target_rect`.
///
/// The marker is a small white dot just left of the rect's top-left corner, with a light-grey
/// outline. A curved white line runs from the dot to a point 16 px below the rect's top-left
/// corner and ends in a two-stroke arrowhead pointing right.
fn render_initial_state_indicator(
    painter: &egui::Painter,
    target_rect: egui::Rect,
) {
    let line = egui::Stroke::new(2.0, egui::Color32::WHITE);
    let dot_radius = 3.0;
    let dot = target_rect.left_top() + egui::Vec2::new(-13.0, 1.0);

    painter.circle_filled(dot, dot_radius, egui::Color32::WHITE);
    painter.circle_stroke(
        dot,
        dot_radius,
        egui::Stroke::new(1.5, egui::Color32::from_rgb(200, 200, 200)),
    );

    // Curve from the bottom of the dot to the node's left edge, bending downward first.
    let from = dot + egui::Vec2::new(0.0, dot_radius);
    let to = egui::Pos2::new(target_rect.left(), target_rect.top() + 16.0);
    let bend = egui::Pos2::new(from.x, from.y + (to.y - from.y) * 0.7);
    let steps = 12;
    let mut last = from;
    for step in 1..=steps {
        let next = quadratic_bezier(from, bend, to, step as f32 / steps as f32);
        painter.line_segment([last, next], line);
        last = next;
    }

    // Arrowhead barbs: 4 px back along +x and 2 px to either side.
    let barb_back = 4.0;
    let barb_spread = 2.0;
    painter.line_segment([to, to + egui::Vec2::new(-barb_back, barb_spread)], line);
    painter.line_segment([to, to + egui::Vec2::new(-barb_back, -barb_spread)], line);
}
/// Evaluates the quadratic Bézier curve through `start`, `control`, `end` at parameter `t`.
///
/// `t = 0` gives `start` and `t = 1` gives `end`.
fn quadratic_bezier(start: egui::Pos2, control: egui::Pos2, end: egui::Pos2, t: f32) -> egui::Pos2 {
    let s = 1.0 - t;
    // Degree-2 Bernstein weights.
    let w_start = s * s;
    let w_control = 2.0 * s * t;
    let w_end = t * t;
    let blend = |a: f32, b: f32, c: f32| w_start * a + w_control * b + w_end * c;
    egui::Pos2::new(
        blend(start.x, control.x, end.x),
        blend(start.y, control.y, end.y),
    )
}
fn draw_dashed_arrow(painter: &egui::Painter, start: egui::Pos2, end: egui::Pos2, color: egui::Color32) {
let direction = end - start;
let distance = direction.length();
if distance < 1.0 {
return; }
let normalized_direction = direction / distance;
let dash_length = 8.0;
let gap_length = 4.0;
let dash_and_gap = dash_length + gap_length;
let mut current_distance = 0.0;
while current_distance < distance {
let dash_start = start + normalized_direction * current_distance;
let dash_end_distance = (current_distance + dash_length).min(distance);
let dash_end = start + normalized_direction * dash_end_distance;
painter.line_segment(
[dash_start, dash_end],
egui::Stroke::new(2.0, color),
);
current_distance += dash_and_gap;
}
let arrowhead_size = 6.0;
let perpendicular = egui::Vec2::new(-normalized_direction.y, normalized_direction.x);
let arrowhead_point1 = end - normalized_direction * arrowhead_size + perpendicular * (arrowhead_size * 0.5);
let arrowhead_point2 = end - normalized_direction * arrowhead_size - perpendicular * (arrowhead_size * 0.5);
painter.line_segment(
[end, arrowhead_point1],
egui::Stroke::new(2.0, color),
);
painter.line_segment(
[end, arrowhead_point2],
egui::Stroke::new(2.0, color),
);
}
/// Returns `true` if `source` appears anywhere in `target`'s chain of `StateChildOf` parents.
/// `target` itself is not part of its own chain.
fn is_ancestor_of(source: Entity, target: Entity, child_of_query: &Query<&bevy_gearbox::StateChildOf>) -> bool {
    let parent_of = |entity: Entity| child_of_query.get(entity).ok().map(|rel| rel.0);
    std::iter::successors(parent_of(target), |&entity| parent_of(entity))
        .any(|ancestor| ancestor == source)
}
/// Draws a small hook-shaped curve starting at the point of `parent_rect`'s edge nearest to
/// `event_pos`, then a straight line from the hook's tip to `event_pos`.
///
/// The hook first extends 10 px further in the direction from `event_pos` toward that edge
/// point, bends 10 px sideways, and comes back 10 px. When the two points nearly coincide,
/// +x is used as the direction.
fn draw_fish_hook_to_point(
    painter: &egui::Painter,
    parent_rect: egui::Rect,
    event_pos: egui::Pos2,
    color: egui::Color32,
) {
    const HOOK: f32 = 10.0;

    let anchor = closest_point_on_rect_edge(parent_rect, event_pos);
    let offset = anchor - event_pos;
    let dist = offset.length();
    let away = if dist > 1e-3 {
        offset / dist
    } else {
        egui::Vec2::new(1.0, 0.0)
    };
    let side = egui::Vec2::new(-away.y, away.x);

    let c1 = anchor + away * HOOK;
    let c2 = c1 + side * HOOK;
    let tip = c2 - away * HOOK;
    draw_cubic_bezier(painter, anchor, c1, c2, tip, color);
    painter.line_segment([tip, event_pos], egui::Stroke::new(2.0, color));
}
/// Approximates the cubic Bézier curve `p0`–`p3` with 24 straight 2 px segments in `color`.
fn draw_cubic_bezier(
    painter: &egui::Painter,
    p0: egui::Pos2,
    p1: egui::Pos2,
    p2: egui::Pos2,
    p3: egui::Pos2,
    color: egui::Color32,
) {
    const STEPS: u32 = 24;
    let stroke = egui::Stroke::new(2.0, color);
    (1..=STEPS).fold(p0, |from, step| {
        let to = cubic_bezier_point(p0, p1, p2, p3, step as f32 / STEPS as f32);
        painter.line_segment([from, to], stroke);
        to
    });
}
/// Evaluates the cubic Bézier curve with control points `p0`..`p3` at parameter `t`.
///
/// `t = 0` gives `p0` and `t = 1` gives `p3`.
fn cubic_bezier_point(p0: egui::Pos2, p1: egui::Pos2, p2: egui::Pos2, p3: egui::Pos2, t: f32) -> egui::Pos2 {
    let s = 1.0 - t;
    let s2 = s * s;
    let t2 = t * t;
    // Degree-3 Bernstein weights.
    let b0 = s2 * s;
    let b1 = 3.0 * s2 * t;
    let b2 = 3.0 * s * t2;
    let b3 = t2 * t;
    let mix = |c0: f32, c1: f32, c2: f32, c3: f32| b0 * c0 + b1 * c1 + b2 * c2 + b3 * c3;
    egui::Pos2::new(mix(p0.x, p1.x, p2.x, p3.x), mix(p0.y, p1.y, p2.y, p3.y))
}
/// Counts how many `StateChildOf` links separate `entity` from its top-most ancestor.
/// An entity with no parent has depth 0.
fn hierarchy_depth_from_pairs(entity: Entity, child_of_query: &Query<&bevy_gearbox::StateChildOf>) -> usize {
    let parent_of = |e: Entity| child_of_query.get(e).ok().map(|rel| rel.0);
    std::iter::successors(parent_of(entity), |&ancestor| parent_of(ancestor)).count()
}