use crate::types::position::CoordinateExtent;
use std::collections::HashMap;
use crate::animation::viewport_animation::ViewportAnimation;
use crate::config::FlowConfig;
use crate::graph::utils::get_nodes_bounds;
use crate::types::changes::{EdgeChange, NodeChange};
use crate::types::connection::ConnectionState;
use crate::types::edge::Edge;
use crate::types::handle::{Handle, HandleType};
use crate::types::node::{InternalNode, Node, NodeHandleBounds, NodeId, NodeInternals};
use crate::types::position::Position;
use crate::types::viewport::Viewport;
use super::changes::{apply_edge_changes, apply_node_changes};
/// Editor state for a node/edge flow: the node and edge lists, a derived per-id node
/// lookup, the viewport (plus any in-flight viewport animation) and interaction state.
///
/// `ND` and `ED` are the type parameters forwarded to `Node` and `Edge` respectively.
pub struct FlowState<ND = (), ED = ()> {
/// Node list; `node_lookup` is derived from it by `rebuild_lookup`.
pub nodes: Vec<Node<ND>>,
/// Edge list.
pub edges: Vec<Edge<ED>>,
/// Per-id derived data (absolute position, z, handle bounds) together with a clone of
/// each node. Rebuilt from `nodes` by `rebuild_lookup` and patched in place for
/// position/dimension/selection changes.
pub node_lookup: HashMap<NodeId, InternalNode<ND>>,
/// Current viewport; overwritten by `tick_animation` while an animation is running.
pub viewport: Viewport,
/// Starts as `ConnectionState::None`. NOTE(review): presumably tracks an in-progress
/// edge connection; it is not otherwise touched in this impl, so confirm with callers.
pub connection_state: ConnectionState,
/// NOTE(review): presumably the in-progress selection rectangle; not used in this impl.
pub selection_rect: Option<egui::Rect>,
/// Configuration read for zoom limits, default node size, handle size and the default
/// viewport transition.
pub config: FlowConfig,
/// In-flight viewport animation; `tick_animation` drops it once it reports inactive.
pub viewport_animation: Option<ViewportAnimation>,
/// Whether any edge is animated: set by `add_edge`/`add_edges`, recomputed by
/// `apply_edge_changes`.
pub has_animated_edges: bool,
/// Cached result of `sorted_node_ids`; cleared by `rebuild_lookup`.
sorted_ids_cache: Option<Vec<NodeId>>,
}
impl<ND: Clone, ED: Clone> FlowState<ND, ED> {
    /// Creates an empty flow using `config`.
    ///
    /// The flow starts with:
    /// - no nodes or edges;
    /// - a default viewport;
    /// - `ConnectionState::None`;
    /// - no selection rectangle and no viewport animation.
    pub fn new(config: FlowConfig) -> Self {
        Self {
            nodes: Vec::new(),
            edges: Vec::new(),
            node_lookup: HashMap::new(),
            viewport: Viewport::default(),
            connection_state: ConnectionState::None,
            selection_rect: None,
            config,
            viewport_animation: None,
            has_animated_edges: false,
            sorted_ids_cache: None,
        }
    }

    /// Appends a single node and rebuilds the node lookup.
    pub fn add_node(&mut self, node: Node<ND>) {
        self.nodes.push(node);
        self.rebuild_lookup();
    }

    /// Appends several nodes.
    ///
    /// The lookup is rebuilt once, and only if at least one node was actually added.
    pub fn add_nodes(&mut self, nodes: impl IntoIterator<Item = Node<ND>>) {
        let before = self.nodes.len();
        self.nodes.extend(nodes);
        if self.nodes.len() > before {
            self.rebuild_lookup();
        }
    }

    /// Appends an edge, setting `has_animated_edges` if the edge is animated.
    pub fn add_edge(&mut self, edge: Edge<ED>) {
        if edge.animated {
            self.has_animated_edges = true;
        }
        self.edges.push(edge);
    }

    /// Appends several edges.
    ///
    /// Sets `has_animated_edges` if any of them is animated. The flag is never cleared
    /// here; `apply_edge_changes` recomputes it.
    pub fn add_edges(&mut self, edges: impl IntoIterator<Item = Edge<ED>>) {
        let mut any_animated = false;
        self.edges
            .extend(edges.into_iter().inspect(|e| any_animated |= e.animated));
        if any_animated {
            self.has_animated_edges = true;
        }
    }

    /// Applies `changes` to `nodes` and keeps `node_lookup` in sync.
    ///
    /// Any `Add`, `Remove` or `Replace` change forces a full `rebuild_lookup`. Otherwise
    /// only the affected lookup entries are patched in place.
    pub fn apply_node_changes(&mut self, changes: &[NodeChange<ND>]) {
        let has_structural = changes.iter().any(|c| {
            matches!(
                c,
                NodeChange::Add { .. } | NodeChange::Remove { .. } | NodeChange::Replace { .. }
            )
        });
        apply_node_changes(changes, &mut self.nodes);
        if has_structural {
            self.rebuild_lookup();
        } else {
            self.apply_incremental_lookup_updates(changes);
        }
    }

    /// Mirrors non-structural node changes into `node_lookup` without a full rebuild.
    ///
    /// - `Position` updates the relative position and the dragging flag.
    /// - `Dimensions` updates the measured size, width/height and handle bounds.
    /// - `Select` updates the selected flag.
    ///
    /// Other variants are ignored. If any position changed, absolute positions of
    /// parented nodes are re-resolved afterwards.
    fn apply_incremental_lookup_updates(&mut self, changes: &[NodeChange<ND>]) {
        let mut needs_parent_update = false;
        for change in changes {
            match change {
                NodeChange::Position { id, position, dragging } => {
                    if let Some(internal) = self.node_lookup.get_mut(id) {
                        if let Some(pos) = position {
                            internal.node.position = *pos;
                            // Correct for root nodes; parented nodes are fixed up below.
                            internal.internals.position_absolute = *pos;
                            needs_parent_update = true;
                        }
                        if let Some(d) = dragging {
                            internal.node.dragging = *d;
                        }
                    }
                }
                NodeChange::Dimensions { id, dimensions } => {
                    if let Some(internal) = self.node_lookup.get_mut(id) {
                        internal.node.measured = *dimensions;
                        if let Some(d) = dimensions {
                            internal.node.width = Some(d.width);
                            internal.node.height = Some(d.height);
                        }
                        // Handle placement is derived from the node size, so recompute it.
                        internal.internals.handle_bounds =
                            build_handle_bounds(&internal.node, &self.config);
                    }
                }
                NodeChange::Select { id, selected } => {
                    if let Some(internal) = self.node_lookup.get_mut(id) {
                        internal.node.selected = *selected;
                    }
                }
                _ => {}
            }
        }
        if needs_parent_update {
            self.update_absolute_positions();
        }
    }

    /// Applies `changes` to `edges`, then recomputes `has_animated_edges` from the result.
    pub fn apply_edge_changes(&mut self, changes: &[EdgeChange<ED>]) {
        apply_edge_changes(changes, &mut self.edges);
        self.has_animated_edges = self.edges.iter().any(|e| e.animated);
    }

    /// Rebuilds `node_lookup` from scratch out of `nodes`.
    ///
    /// Each entry starts with `position_absolute = position`, `z = z_index` (or 0) and
    /// freshly built handle bounds. Parented nodes are then resolved by
    /// `update_absolute_positions`. Also invalidates the sorted-id cache.
    ///
    /// If `nodes` contains duplicate ids, the last one wins.
    pub fn rebuild_lookup(&mut self) {
        self.sorted_ids_cache = None;
        self.node_lookup.clear();
        for node in &self.nodes {
            let handle_bounds = build_handle_bounds(node, &self.config);
            let internal = InternalNode {
                internals: NodeInternals {
                    position_absolute: node.position,
                    z: node.z_index.unwrap_or(0),
                    handle_bounds,
                },
                node: node.clone(),
            };
            self.node_lookup.insert(node.id.clone(), internal);
        }
        self.update_absolute_positions();
    }

    /// Resolves `position_absolute`, and raises `z`, for every node that has a parent.
    ///
    /// A child's absolute position is its parent's absolute position plus its own
    /// relative `position`. Its `z` is raised to at least `parent z + 1`. Root nodes are
    /// left untouched, as are children whose parent is not in the lookup.
    ///
    /// Parent links are processed shallowest-first. For nested groups this means every
    /// parent is already resolved before its children read it. Snapshotting all parents
    /// up front would give grandchildren a relative position and a stale z.
    fn update_absolute_positions(&mut self) {
        if !self.nodes.iter().any(|n| n.parent_id.is_some()) {
            return;
        }
        let mut links: Vec<(usize, NodeId, NodeId)> = self
            .nodes
            .iter()
            .filter_map(|n| {
                let parent_id = n.parent_id.as_ref()?;
                Some((self.parent_depth(&n.id), n.id.clone(), parent_id.clone()))
            })
            .collect();
        // Stable sort: links at the same depth keep their `nodes` order.
        links.sort_by_key(|(depth, _, _)| *depth);

        let mut z_changed = false;
        for (_, child_id, parent_id) in &links {
            let parent = self
                .node_lookup
                .get(parent_id)
                .map(|p| (p.internals.position_absolute, p.internals.z));
            let (parent_pos, parent_z) = match parent {
                Some(data) => data,
                None => continue,
            };
            if let Some(child) = self.node_lookup.get_mut(child_id) {
                child.internals.position_absolute = egui::pos2(
                    parent_pos.x + child.node.position.x,
                    parent_pos.y + child.node.position.y,
                );
                if child.internals.z <= parent_z {
                    child.internals.z = parent_z + 1;
                    z_changed = true;
                }
            }
        }
        // `sorted_node_ids` orders by z, so a bump must not leave a stale order cached.
        if z_changed {
            self.sorted_ids_cache = None;
        }
    }

    /// Number of `parent_id` links above `id`, followed through `node_lookup`.
    ///
    /// The walk is capped at the number of nodes, so a parent cycle cannot loop forever.
    fn parent_depth(&self, id: &NodeId) -> usize {
        let limit = self.node_lookup.len();
        let mut depth = 0;
        let mut current = id;
        while let Some(parent) = self
            .node_lookup
            .get(current)
            .and_then(|n| n.node.parent_id.as_ref())
        {
            depth += 1;
            if depth > limit {
                break;
            }
            current = parent;
        }
        depth
    }

    /// Animates the viewport so that the bounds from `get_nodes_bounds` fit the canvas.
    ///
    /// Does nothing when those bounds are `Rect::NOTHING`.
    pub fn fit_view(&mut self, canvas_rect: egui::Rect, padding: f32, current_time: f64) {
        let bounds = get_nodes_bounds(&self.node_lookup);
        if bounds == egui::Rect::NOTHING {
            return;
        }
        let target = self.viewport_for_bounds(bounds, canvas_rect, padding);
        self.animate_viewport(target, current_time);
    }

    /// Animates the viewport so that the flow-space extent `bounds` fits the canvas.
    ///
    /// Does nothing if the extent has non-positive width or height.
    pub fn fit_bounds(
        &mut self,
        bounds: CoordinateExtent,
        canvas_rect: egui::Rect,
        padding: f32,
        current_time: f64,
    ) {
        let flow_rect = egui::Rect::from_min_max(bounds.min, bounds.max);
        if flow_rect.width() <= 0.0 || flow_rect.height() <= 0.0 {
            return;
        }
        let target = self.viewport_for_bounds(flow_rect, canvas_rect, padding);
        self.animate_viewport(target, current_time);
    }

    /// Animates the viewport to frame every selected, non-hidden node.
    ///
    /// Does nothing when no node qualifies.
    pub fn fit_selected_nodes(&mut self, canvas_rect: egui::Rect, padding: f32, current_time: f64) {
        let bounds = self
            .node_lookup
            .values()
            .filter(|n| n.node.selected && !n.node.hidden)
            .fold(egui::Rect::NOTHING, |acc, n| acc.union(n.rect()));
        if bounds == egui::Rect::NOTHING {
            return;
        }
        let target = self.viewport_for_bounds(bounds, canvas_rect, padding);
        self.animate_viewport(target, current_time);
    }

    /// Shared helper for the `fit_*` methods.
    ///
    /// Calls `get_viewport_for_bounds` with the canvas width/height and the configured
    /// `min_zoom`/`max_zoom`. Only the canvas size is used, not its screen offset.
    fn viewport_for_bounds(&self, bounds: egui::Rect, canvas_rect: egui::Rect, padding: f32) -> Viewport {
        crate::graph::utils::get_viewport_for_bounds(
            bounds,
            canvas_rect.width(),
            canvas_rect.height(),
            self.config.min_zoom,
            self.config.max_zoom,
            padding,
        )
    }

    /// Animates zoom up by a factor of 1.2, capped at `max_zoom`.
    ///
    /// The viewport x/y offset is unchanged.
    pub fn zoom_in(&mut self, current_time: f64) {
        let target = Viewport {
            zoom: (self.viewport.zoom * 1.2).min(self.config.max_zoom),
            ..self.viewport
        };
        self.animate_viewport(target, current_time);
    }

    /// Animates zoom down by a factor of 1.2, floored at `min_zoom`.
    ///
    /// The viewport x/y offset is unchanged.
    pub fn zoom_out(&mut self, current_time: f64) {
        let target = Viewport {
            zoom: (self.viewport.zoom / 1.2).max(self.config.min_zoom),
            ..self.viewport
        };
        self.animate_viewport(target, current_time);
    }

    /// Animates the viewport so that flow point (`x`, `y`) lands at `canvas_rect.center()`.
    ///
    /// Uses `zoom`, or the current zoom when `None`, clamped to the configured range.
    // NOTE(review): this offsets from `canvas_rect.center()` (absolute screen coordinates),
    // while the `fit_*` methods pass only the canvas width/height to
    // `get_viewport_for_bounds`. If the canvas does not start at the screen origin, the two
    // conventions disagree by `canvas_rect.min`. Confirm which one the renderer expects.
    pub fn set_center(
        &mut self,
        x: f32,
        y: f32,
        zoom: Option<f32>,
        canvas_rect: egui::Rect,
        current_time: f64,
    ) {
        let target_zoom = zoom
            .unwrap_or(self.viewport.zoom)
            .clamp(self.config.min_zoom, self.config.max_zoom);
        let target = Viewport {
            x: canvas_rect.center().x - x * target_zoom,
            y: canvas_rect.center().y - y * target_zoom,
            zoom: target_zoom,
        };
        self.animate_viewport(target, current_time);
    }

    /// Starts an animation from the current viewport to `target`.
    ///
    /// Uses an explicit `duration` and `easing`. The target zoom is not clamped.
    pub fn set_viewport(
        &mut self,
        target: Viewport,
        duration: f32,
        easing: fn(f32) -> f32,
        current_time: f64,
    ) {
        self.viewport_animation = Some(ViewportAnimation::new(
            self.viewport,
            target,
            duration,
            current_time,
            easing,
        ));
    }

    /// Starts an animation from the current viewport to `target`.
    ///
    /// Uses the configured default transition duration and easing.
    pub fn animate_viewport(&mut self, target: Viewport, current_time: f64) {
        self.viewport_animation = Some(ViewportAnimation::new(
            self.viewport,
            target,
            self.config.default_transition_duration,
            current_time,
            self.config.default_transition_easing,
        ));
    }

    /// Advances the running viewport animation to `current_time` and writes the result
    /// into `viewport`.
    ///
    /// Returns `true` while the animation is still active. Returns `false` once it
    /// finishes (the animation is then dropped) or when no animation is running.
    pub fn tick_animation(&mut self, current_time: f64) -> bool {
        if let Some(ref mut anim) = self.viewport_animation {
            self.viewport = anim.tick(current_time);
            if !anim.active {
                self.viewport_animation = None;
                return false;
            }
            return true;
        }
        false
    }

    /// Returns every lookup id ordered by ascending z.
    ///
    /// Presumably this is the back-to-front render order; confirm with callers. Nodes
    /// with equal z keep their order in `nodes`, so the result is deterministic instead
    /// of following `HashMap` iteration order. The result is cached until the lookup is
    /// rebuilt or a z value changes.
    pub fn sorted_node_ids(&mut self) -> Vec<NodeId> {
        if let Some(ref cached) = self.sorted_ids_cache {
            return cached.clone();
        }
        // Later duplicates overwrite earlier ones, matching which node wins in `rebuild_lookup`.
        let order: HashMap<&NodeId, usize> = self
            .nodes
            .iter()
            .enumerate()
            .map(|(i, n)| (&n.id, i))
            .collect();
        let mut ids: Vec<NodeId> = self.node_lookup.keys().cloned().collect();
        ids.sort_by_key(|id| {
            let z = self.node_lookup.get(id).map_or(0, |n| n.internals.z);
            (z, order.get(id).copied().unwrap_or(usize::MAX))
        });
        self.sorted_ids_cache = Some(ids.clone());
        ids
    }
}
/// Builds node-local bounds for every `Source` and `Target` handle declared on `node`.
///
/// The node size falls back to `config.default_node_width`/`default_node_height` when
/// `width`/`height` are unset. Every handle is `config.handle_size` square.
///
/// Handles of the same type on the same side are spread evenly along that side by
/// `compute_handle_offset`. Source and target handles are ranked independently of each
/// other. Output order follows declaration order within each type.
fn build_handle_bounds<D>(node: &Node<D>, config: &FlowConfig) -> NodeHandleBounds {
    let width = node.width.unwrap_or(config.default_node_width);
    let height = node.height.unwrap_or(config.default_node_height);
    let size = config.handle_size;

    let mut source = Vec::new();
    let mut target = Vec::new();

    for (slot, nh) in node.handles.iter().enumerate() {
        let is_source = nh.handle_type == HandleType::Source;
        if !is_source && nh.handle_type != HandleType::Target {
            continue;
        }

        // `count`: handles of this type on this side; `rank`: how many of those precede `nh`.
        let mut count = 0;
        let mut rank = 0;
        for (other_slot, other) in node.handles.iter().enumerate() {
            if other.handle_type == nh.handle_type && other.position == nh.position {
                count += 1;
                if other_slot < slot {
                    rank += 1;
                }
            }
        }

        let (x, y) = compute_handle_offset(nh.position, width, height, size, count, rank);
        let handle = Handle {
            id: nh.id.clone(),
            node_id: node.id.0.clone(),
            x,
            y,
            position: nh.position,
            handle_type: if is_source {
                HandleType::Source
            } else {
                HandleType::Target
            },
            width: size,
            height: size,
        };

        if is_source {
            source.push(handle);
        } else {
            target.push(handle);
        }
    }

    NodeHandleBounds { source, target }
}
/// Returns the top-left offset `(x, y)` of a `handle_size` square handle, relative to the
/// node's top-left corner.
///
/// - `Top`/`Bottom`/`Left`/`Right`: the handle is centred on that edge. It sits at slot
///   `index + 1` of `count + 1` equal divisions of the edge length.
/// - `Center` and `Closest`: the handle is centred in the node; `count` and `index` are
///   ignored.
fn compute_handle_offset(
    position: Position,
    node_w: f32,
    node_h: f32,
    handle_size: f32,
    count: usize,
    index: usize,
) -> (f32, f32) {
    let half = handle_size / 2.0;
    // Offset along an edge of length `edge`, minus `half` so the handle is centred on its slot.
    let along = |edge: f32| edge / (count as f32 + 1.0) * (index as f32 + 1.0) - half;

    match position {
        Position::Top => (along(node_w), -half),
        Position::Bottom => (along(node_w), node_h - half),
        Position::Left => (-half, along(node_h)),
        Position::Right => (node_w - half, along(node_h)),
        Position::Center | Position::Closest => (node_w / 2.0 - half, node_h / 2.0 - half),
    }
}