use std::collections::HashSet;
use egui::Pos2;
use petgraph::{
csr::IndexType,
stable_graph::NodeIndex,
Direction::{Incoming, Outgoing},
EdgeType,
};
use serde::{Deserialize, Serialize};
use crate::{
layouts::{Layout, LayoutState},
DisplayEdge, DisplayNode, Graph,
};
/// Axis along which the hierarchy grows when node positions are computed.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
pub enum Orientation {
/// Rows map to the y axis and columns to the x axis, so depth increases
/// with y. This is the default variant.
#[default]
TopDown,
/// Rows map to the x axis and columns to the y axis, so depth increases
/// with x.
LeftRight,
}
/// Serializable configuration and progress flag of the hierarchical layout.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct State {
/// Set to `true` once the layout has been computed; while it is `true`,
/// `Hierarchical::next` returns without touching the graph.
pub triggered: bool,
/// Distance between two consecutive rows (depth levels).
pub row_dist: f32,
/// Distance between two consecutive columns.
pub col_dist: f32,
/// Intended to request centring a parent over the columns of its children.
/// NOTE(review): inferred from the field name; verify that the layout
/// routine actually honours this flag.
pub center_parent: bool,
/// Which screen axis rows and columns are mapped to.
pub orientation: Orientation,
}
impl Default for State {
fn default() -> Self {
Self {
triggered: false,
row_dist: 50.0,
col_dist: 50.0,
center_parent: false,
orientation: Orientation::TopDown,
}
}
}
// Opts `State` into the `LayoutState` trait; the impl body is empty, so no
// trait items are overridden here.
impl LayoutState for State {}
/// Layout that arranges nodes in rows and columns by walking outgoing edges
/// from each starting node. It owns the `State` that configures the spacing
/// and records whether the layout has already run.
#[derive(Debug, Default)]
pub struct Hierarchical {
/// Configuration plus the `triggered` flag consulted by `next`.
state: State,
}
impl Layout<State> for Hierarchical {
    /// Computes node positions once.
    ///
    /// Does nothing if the state is already marked as triggered. Otherwise
    /// every node without incoming edges is laid out as the top of its own
    /// subtree, followed by any node that is still unvisited (for example
    /// nodes that are only part of cycles). Each subtree starts in the column
    /// right after the widest column used so far. Finally the state is marked
    /// as triggered so later calls are no-ops.
    fn next<N, E, Ty, Ix, Dn, De>(&mut self, g: &mut Graph<N, E, Ty, Ix, Dn, De>, _: &egui::Ui)
    where
        N: Clone,
        E: Clone,
        Ty: EdgeType,
        Ix: IndexType,
        Dn: DisplayNode<N, E, Ty, Ix>,
        De: DisplayEdge<N, E, Ty, Ix, Dn>,
    {
        if self.state.triggered {
            return;
        }

        // Snapshot the start candidates up front: roots first, then all
        // nodes. Laying out only moves nodes, so the index sets stay valid.
        let sources: Vec<NodeIndex<Ix>> = g.g().externals(Incoming).collect();
        let everything: Vec<NodeIndex<Ix>> = g.g().node_indices().collect();

        let mut seen: HashSet<NodeIndex<Ix>> = HashSet::new();
        let mut free_col: usize = 0;
        for start in sources.iter().chain(everything.iter()) {
            if !seen.contains(start) {
                let widest = layout_tree(g, &mut seen, start, &self.state, 0, free_col);
                free_col = widest + 1;
            }
        }

        self.state.triggered = true;
    }

    /// Returns a copy of the current layout state.
    fn state(&self) -> State {
        self.state.clone()
    }

    /// Creates a hierarchical layout that starts from the given state.
    fn from_state(state: State) -> impl Layout<State> {
        Hierarchical { state }
    }
}
/// Recursively positions `root_idx` and every not-yet-visited node reachable
/// from it through outgoing edges.
///
/// Children are laid out first, one row below the parent, each child subtree
/// starting in the column right after the previous sibling's widest column.
/// The parent is then placed in row `start_row`:
///
/// * at `start_col` when `state.center_parent` is `false` or the node has no
///   newly placed children;
/// * at the midpoint of the column span `[start_col, max_col]` used by its
///   subtree when `state.center_parent` is `true`.
///
/// `state.orientation` decides which screen axis rows and columns map to, and
/// `state.row_dist` / `state.col_dist` scale the grid into coordinates.
///
/// Every node placed by this call is added to `visited`, so nodes reachable
/// via several paths (or via cycles) are positioned only once.
///
/// Returns the largest column index occupied by this subtree.
fn layout_tree<N, E, Ty, Ix, Dn, De>(
    g: &mut Graph<N, E, Ty, Ix, Dn, De>,
    visited: &mut HashSet<NodeIndex<Ix>>,
    root_idx: &NodeIndex<Ix>,
    state: &State,
    start_row: usize,
    start_col: usize,
) -> usize
where
    N: Clone,
    E: Clone,
    Ty: EdgeType,
    Ix: IndexType,
    Dn: DisplayNode<N, E, Ty, Ix>,
    De: DisplayEdge<N, E, Ty, Ix, Dn>,
{
    // `insert` is a no-op for an already present value, so no prior check is needed.
    visited.insert(*root_idx);

    // Collect first: the recursion below needs `g` mutably.
    let children: Vec<NodeIndex<Ix>> = g.g().neighbors_directed(*root_idx, Outgoing).collect();

    let mut had_child = false;
    let mut max_col = start_col;
    let mut child_col = start_col;
    for child_idx in &children {
        // `insert` returns `false` when the child was already visited
        // (shared descendant, parallel edge or cycle): skip it.
        if !visited.insert(*child_idx) {
            continue;
        }
        had_child = true;

        let child_max_col = layout_tree(g, visited, child_idx, state, start_row + 1, child_col);
        max_col = max_col.max(child_max_col);
        child_col = child_max_col.saturating_add(1);
    }

    // Column of this node in grid units. Centring uses the midpoint of the
    // span occupied by the subtree; the span belongs exclusively to this
    // subtree, so a fractional column cannot collide with a sibling.
    let place_col = if state.center_parent && had_child {
        (start_col as f32 + max_col as f32) / 2.0
    } else {
        start_col as f32
    };

    let (x, y) = match state.orientation {
        Orientation::TopDown => (
            place_col * state.col_dist,
            (start_row as f32) * state.row_dist,
        ),
        Orientation::LeftRight => (
            (start_row as f32) * state.row_dist,
            place_col * state.col_dist,
        ),
    };

    let node = &mut g.g_mut()[*root_idx];
    node.set_location(Pos2::new(x, y));

    max_col
}