use std::collections::{HashMap, HashSet};
use egui::{Context, Id, Pos2, Vec2, pos2};
use crate::ui::TreeizeViewer;
use crate::ui::state::NodeState;
use crate::{NodeId, Treeize};
/// Spacing and anchor parameters for the automatic tree layout.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct LayoutConfig {
    /// Horizontal gap kept between neighbouring subtrees and between separate root trees.
    /// Also serves as the stand-in width of any node whose measured size is unknown.
    pub horizontal_spacing: f32,
    /// Vertical gap between a node's bottom edge and the top of its children's row.
    /// Also serves as the stand-in height of any node whose measured size is unknown.
    pub vertical_spacing: f32,
    /// Anchor of the layout: root trees are laid out rightwards from `start_pos.x` with their
    /// roots at `start_pos.y`. Nodes that no root tree reaches are placed exactly here.
    pub start_pos: Pos2,
}

impl Default for LayoutConfig {
    /// Spacing of `200.0` horizontally and `150.0` vertically, anchored at `Pos2::ZERO`.
    fn default() -> Self {
        Self { horizontal_spacing: 200.0, vertical_spacing: 150.0, start_pos: Pos2::ZERO }
    }
}
/// Read-only view of the graph being laid out.
///
/// It answers how big each node is and which nodes hang directly below it, so the layout
/// passes never touch `Treeize` directly.
trait NodeDimensionProvider {
    /// Returns the `(width, height)` used for `node_id`.
    fn get_size(&self, node_id: NodeId) -> (f32, f32);

    /// Returns the direct children of `node_id`, in the left-to-right order they are packed.
    fn get_children(&self, node_id: NodeId) -> Vec<NodeId>;
}
/// [`NodeDimensionProvider`] backed by a `Treeize` graph.
///
/// The parent → children relation is extracted from the wire list once at construction time.
/// The optional table of measured sizes is borrowed for `'a`.
struct TreeizeAdapter<'a, T> {
    /// Measured node sizes; `None`, or a missing entry, means "use the config fallback".
    node_sizes: Option<&'a HashMap<NodeId, Vec2>>,
    /// Parent node → child nodes, in the order the connecting wires were encountered.
    children_map: HashMap<NodeId, Vec<NodeId>>,
    /// Layout settings; they supply the fallback node size.
    config: LayoutConfig,
    /// Ties the adapter to the `Treeize<T>` payload type without storing any `T`.
    _phantom: std::marker::PhantomData<T>,
}
impl<'a, T> TreeizeAdapter<'a, T> {
fn new(
treeize: &Treeize<T>,
node_sizes: Option<&'a HashMap<NodeId, Vec2>>,
has_output: &mut impl FnMut(NodeId) -> bool,
has_input: &mut impl FnMut(NodeId) -> bool,
config: LayoutConfig,
) -> Self {
let mut children_map: HashMap<NodeId, Vec<NodeId>> = HashMap::new();
for (out_pin, in_pin) in treeize.wires() {
let from_node = out_pin.node;
let to_node = in_pin.node;
if has_output(from_node) && has_input(to_node) {
children_map.entry(from_node).or_default().push(to_node);
}
}
TreeizeAdapter { node_sizes, children_map, config, _phantom: std::marker::PhantomData }
}
}
impl<T> NodeDimensionProvider for TreeizeAdapter<'_, T> {
    /// Returns the measured size when one is known.
    ///
    /// Otherwise it falls back to `(horizontal_spacing, vertical_spacing)` from the config as
    /// a stand-in width and height.
    fn get_size(&self, node_id: NodeId) -> (f32, f32) {
        let fallback = (self.config.horizontal_spacing, self.config.vertical_spacing);
        self.node_sizes
            .and_then(|sizes| sizes.get(&node_id))
            .map_or(fallback, |size| (size.x, size.y))
    }

    /// Returns an owned copy of the recorded children; empty for leaves and unknown nodes.
    fn get_children(&self, node_id: NodeId) -> Vec<NodeId> {
        self.children_map.get(&node_id).map_or_else(Vec::new, |kids| kids.to_vec())
    }
}
/// Per-node result of the bottom-up layout pass.
struct LayoutNode {
    /// Horizontal offset of this node's center from its parent's center (`0.0` for roots).
    offset_x: f32,
    /// One `(left, right)` pair per depth below this node, relative to this node's center.
    /// Index `0` is the node itself.
    contour: Vec<(f32, f32)>,
    /// Node width as reported by the dimension provider.
    width: f32,
    /// Node height as reported by the dimension provider.
    height: f32,
}
/// Scratch state shared by the two layout passes: one [`LayoutNode`] per visited node.
struct LayoutState {
    nodes: HashMap<NodeId, LayoutNode>,
}

impl LayoutState {
    /// Creates a state with no nodes recorded yet.
    fn new() -> Self {
        LayoutState { nodes: HashMap::default() }
    }
}
/// First (bottom-up) pass of the layout.
///
/// Lays out `node_id` and its whole subtree relative to each parent and records the results
/// in `state`.
///
/// Every visited node gets a [`LayoutNode`]. Its `contour[d]` is the `(left, right)` horizontal
/// extent, relative to that node's center, of everything `d` levels below it (`d == 0` is the
/// node itself). Each child's `offset_x` is set to the child's center relative to this node's
/// center. This node's own `offset_x` is stored as `0.0` and is overwritten by its parent's
/// call, if any.
///
/// Children are packed left to right in `provider.get_children` order. Each child after the
/// first is shifted right just far enough that, at every depth both contours reach, its left
/// edge lies at least `config.horizontal_spacing` to the right of the children already placed.
/// The node is then centered midway between its first and last child.
///
/// NOTE(review): contours are compared per depth index, not per y range. `second_walk` stacks
/// rows using each ancestor's own height, so with differing node heights, different depths can
/// land at overlapping y ranges. Such overlaps are not detected here.
///
/// There is no visited check; the recursion assumes `provider` describes a tree. A node listed
/// under several parents is laid out again on each visit, replacing its `LayoutNode`. A cycle
/// reachable from `node_id` would recurse without bound.
fn first_walk<P: NodeDimensionProvider>(
node_id: NodeId,
provider: &P,
config: &LayoutConfig,
state: &mut LayoutState,
) {
let (w, h) = provider.get_size(node_id);
let children = provider.get_children(node_id);
// Leaf: the contour is just the node's own width, centered on itself.
if children.is_empty() {
state.nodes.insert(
node_id,
LayoutNode { offset_x: 0.0, contour: vec![(-w / 2.0, w / 2.0)], width: w, height: h },
);
return;
}
// Post-order: every child's subtree must be laid out before the children can be packed.
for child_id in &children {
first_walk(*child_id, provider, config, state);
}
// Running per-depth union of the children placed so far, and each child's center offset.
// Both are expressed relative to the first child's center.
let mut merged_contour: Vec<(f32, f32)> = Vec::new();
let mut child_offsets: Vec<f32> = Vec::new();
// Offset of the most recently placed child; after the loop, that of the last child.
let mut current_offset = 0.0;
for (i, child_id) in children.iter().enumerate() {
// Always present: the recursive call above inserted an entry for every child.
let child_node = state.nodes.get(child_id).unwrap();
if i == 0 {
// The first child anchors the frame at offset 0.
merged_contour.clone_from(&child_node.contour);
child_offsets.push(0.0);
} else {
// Smallest non-negative shift that keeps `horizontal_spacing` between the merged right
// edge and this child's left edge at every depth both contours share.
let mut shift = 0.0f32;
for ((_l1, r1), (l2, _r2)) in merged_contour.iter().zip(child_node.contour.iter()) {
let dist = (r1 + config.horizontal_spacing) - l2;
if dist > shift {
shift = dist;
}
}
child_offsets.push(shift);
current_offset = shift;
// Fold the shifted child contour into the union, extending it for deeper levels.
for (depth, (l, r)) in child_node.contour.iter().enumerate() {
let shifted_l = l + shift;
let shifted_r = r + shift;
if depth < merged_contour.len() {
let (exist_l, exist_r) = merged_contour[depth];
merged_contour[depth] = (exist_l.min(shifted_l), exist_r.max(shifted_r));
} else {
merged_contour.push((shifted_l, shifted_r));
}
}
}
}
// Center this node halfway between the first child (offset 0) and the last child.
let children_center = current_offset / 2.0;
let shift_children_left = -children_center;
// This node's contour: its own level first, then the children's union re-expressed
// relative to this node's center.
let mut final_contour = Vec::new();
final_contour.push((-w / 2.0, w / 2.0));
for (l, r) in merged_contour {
final_contour.push((l + shift_children_left, r + shift_children_left));
}
// Record each child's center relative to this node's center.
for (i, child_id) in children.iter().enumerate() {
if let Some(node) = state.nodes.get_mut(child_id) {
node.offset_x = child_offsets[i] + shift_children_left;
}
}
// `offset_x` is a placeholder; a parent call overwrites it, and roots keep 0.0.
state.nodes.insert(
node_id,
LayoutNode {
offset_x: 0.0, contour: final_contour,
width: w,
height: h,
},
);
}
/// Second (top-down) pass of the layout.
///
/// Converts the relative offsets computed by `first_walk` into absolute top-left positions and
/// stores them in `positions`.
///
/// `absolute_x` is the x coordinate of the parent's center; this node's center is that plus
/// its own `offset_x`. `absolute_y` is this node's top edge. Children are placed
/// `config.vertical_spacing` below this node's bottom edge, relative to this node's center.
///
/// # Panics
/// Panics if `state` lacks an entry for `node_id` or any descendant, i.e. if `first_walk` has
/// not been run on this subtree.
fn second_walk<P: NodeDimensionProvider>(
    node_id: NodeId,
    absolute_x: f32,
    absolute_y: f32,
    state: &LayoutState,
    positions: &mut HashMap<NodeId, Pos2>,
    provider: &P,
    config: &LayoutConfig,
) {
    let node = state.nodes.get(&node_id).unwrap();
    let my_center = absolute_x + node.offset_x;
    let half_width = node.width / 2.0;
    positions.insert(node_id, pos2(my_center - half_width, absolute_y));

    let kids = provider.get_children(node_id);
    let row_below = absolute_y + node.height + config.vertical_spacing;
    for &kid in &kids {
        second_walk(kid, my_center, row_below, state, positions, provider, config);
    }
}
#[allow(clippy::implicit_hasher)]
pub fn layout_tree<T>(
treeize: &Treeize<T>,
config: LayoutConfig,
mut has_output: impl FnMut(NodeId) -> bool,
mut has_input: impl FnMut(NodeId) -> bool,
node_sizes: Option<&HashMap<NodeId, Vec2>>,
) -> HashMap<NodeId, Pos2> {
let mut positions = HashMap::new();
let adapter = TreeizeAdapter::new(treeize, node_sizes, &mut has_output, &mut has_input, config);
let mut has_incoming: HashMap<NodeId, bool> = HashMap::new();
for (out_pin, in_pin) in treeize.wires() {
let to_node = in_pin.node;
if has_output(out_pin.node) && has_input(to_node) {
has_incoming.insert(to_node, true);
}
}
let root_nodes: Vec<NodeId> = treeize
.node_ids()
.map(|(node_id, _)| node_id)
.filter(|node_id| !has_incoming.get(node_id).copied().unwrap_or(false))
.collect();
if root_nodes.is_empty() {
for (node_id, _) in treeize.node_ids() {
positions.insert(node_id, config.start_pos);
}
return positions;
}
let mut layout_state = LayoutState::new();
let mut root_x_offset = 0.0;
for root_id in &root_nodes {
first_walk(*root_id, &adapter, &config, &mut layout_state);
let root_node = layout_state.nodes.get(root_id).unwrap();
let tree_width = root_node.contour.iter().map(|(l, r)| r - l).fold(0.0f32, f32::max);
let root_center_x = config.start_pos.x + root_x_offset + tree_width / 2.0;
second_walk(
*root_id,
root_center_x,
config.start_pos.y,
&layout_state,
&mut positions,
&adapter,
&config,
);
root_x_offset += tree_width + config.horizontal_spacing;
}
for (node_id, _) in treeize.node_ids() {
positions.entry(node_id).or_insert(config.start_pos);
}
positions
}
/// Writes precomputed positions (for example from [`layout_tree`]) back into `treeize`.
///
/// Entries whose `NodeId` is not resolved by `get_node_info_mut` are skipped silently.
pub fn apply_layout<T, H>(treeize: &mut Treeize<T>, positions: &HashMap<NodeId, Pos2, H>)
where
    H: std::hash::BuildHasher,
{
    positions.iter().for_each(|(&id, &target)| {
        if let Some(info) = treeize.get_node_info_mut(id) {
            info.pos = target;
        }
    });
}
/// Runs [`layout_tree`] on `treeize` and writes the result back with [`apply_layout`].
///
/// The parameters mean the same as for [`layout_tree`].
// Mirrors `layout_tree`: `node_sizes` deliberately accepts only the default-hasher `HashMap`.
#[allow(clippy::implicit_hasher)]
pub fn layout_and_apply<T>(
    treeize: &mut Treeize<T>,
    config: LayoutConfig,
    has_output: impl FnMut(NodeId) -> bool,
    has_input: impl FnMut(NodeId) -> bool,
    node_sizes: Option<&HashMap<NodeId, Vec2>>,
) {
    let computed = layout_tree(&*treeize, config, has_output, has_input, node_sizes);
    apply_layout(treeize, &computed);
}
/// Lays out `treeize` using node sizes and pin information gathered from the UI.
///
/// Sizes are read from the per-node UI state stored in `ctx` under
/// `treeize_id.with(("treeize-node", node_id))`. Nodes without stored state use the layout's
/// config-based fallback size. `viewer` is asked once per node whether the node has outputs
/// and inputs, so the closures handed to the layout only perform map lookups.
pub fn layout_with_viewer<T, V>(
    treeize: &mut Treeize<T>,
    viewer: &mut V,
    config: LayoutConfig,
    ctx: &Context,
    treeize_id: Id,
) where
    V: TreeizeViewer<T>,
{
    let mut measured: HashMap<NodeId, Vec2> = HashMap::new();
    for (node_id, _) in treeize.node_ids() {
        let state_id = treeize_id.with(("treeize-node", node_id));
        if let Some(state) = NodeState::pick_data(ctx, state_id) {
            measured.insert(node_id, state.size);
        }
    }
    let node_sizes = (!measured.is_empty()).then_some(&measured);

    // One pass over the nodes, querying outputs before inputs for each node.
    let mut outputs: HashMap<NodeId, bool> = HashMap::new();
    let mut inputs: HashMap<NodeId, bool> = HashMap::new();
    for (node_id, data) in treeize.node_ids() {
        let has_out = viewer.has_output(data);
        let has_in = viewer.has_input(data);
        outputs.insert(node_id, has_out);
        inputs.insert(node_id, has_in);
    }

    let has_output_fn = |id: NodeId| outputs.get(&id) == Some(&true);
    let has_input_fn = |id: NodeId| inputs.get(&id) == Some(&true);
    layout_and_apply(treeize, config, has_output_fn, has_input_fn, node_sizes);
}