use crate::core::types::{BaseGraph, GraphConstructor, NodeId};
use petgraph::EdgeType;
use std::collections::HashMap;
/// A 2-D coordinate assigned to a node by [`LayoutEngine::compute_layout`].
///
/// Coordinates are expressed in the same units as the `width` / `height` canvas
/// passed to the layout call.
#[derive(Debug, Clone, Copy, Default, PartialEq)]
pub struct NodePosition {
    /// Horizontal coordinate (grows with the `width` axis).
    pub x: f64,
    /// Vertical coordinate (grows with the `height` axis).
    pub y: f64,
}
/// Placement strategy used by [`LayoutEngine::compute_layout`].
///
/// The default strategy is [`LayoutAlgorithm::ForceDirected`].
#[derive(Debug, Default)]
#[derive(Clone, Copy, PartialEq, Eq)]
pub enum LayoutAlgorithm {
    /// Spring simulation: all node pairs repel, edges attract their endpoints.
    #[default]
    ForceDirected,
    /// Nodes evenly spaced around a circle centred on the canvas.
    Circular,
    /// Nodes arranged in rows by breadth-first depth from in-degree-0 nodes.
    Hierarchical,
    /// Nodes placed row by row into a near-square grid of equal cells.
    Grid,
    /// Nodes placed at independent uniformly random coordinates.
    Random,
}
/// Stateless entry point for computing node layouts; see [`LayoutEngine::compute_layout`].
#[derive(Debug, Clone, Copy, Default)]
pub struct LayoutEngine;
impl LayoutEngine {
    /// Computes a position for every node of `graph` using `algorithm`.
    ///
    /// `width` and `height` describe the drawing canvas. For positive, finite dimensions
    /// every returned coordinate lies within `[0, width] × [0, height]`. An empty graph
    /// yields an empty map. Degenerate dimensions (zero, negative, NaN or infinite) never
    /// panic. The randomised layouts fall back to `0.0` on such an axis.
    pub fn compute_layout<A, W, Ty: GraphConstructor<A, W> + EdgeType>(
        graph: &BaseGraph<A, W, Ty>,
        algorithm: LayoutAlgorithm,
        width: f64,
        height: f64,
    ) -> HashMap<NodeId, NodePosition> {
        match algorithm {
            LayoutAlgorithm::ForceDirected => Self::force_directed_layout(graph, width, height),
            LayoutAlgorithm::Circular => Self::circular_layout(graph, width, height),
            LayoutAlgorithm::Hierarchical => Self::hierarchical_layout(graph, width, height),
            LayoutAlgorithm::Grid => Self::grid_layout(graph, width, height),
            LayoutAlgorithm::Random => Self::random_layout(graph, width, height),
        }
    }

    /// Fruchterman–Reingold style spring embedding.
    ///
    /// Nodes start at uniformly random positions. For a fixed number of iterations:
    /// - every pair of nodes repels with force `k² / d`;
    /// - every edge pulls its endpoints together with force `d² / k`;
    /// - `k = sqrt(width * height / node_count)`.
    ///
    /// Each node's step is capped by a "temperature". It starts at a tenth of the larger
    /// canvas side and decays geometrically. Positions are clamped to the canvas after
    /// every step.
    ///
    /// If the canvas has no positive, finite area, the simulation is skipped and the
    /// starting positions are returned. Running it would divide by `k = 0`.
    fn force_directed_layout<A, W, Ty: GraphConstructor<A, W> + EdgeType>(
        graph: &BaseGraph<A, W, Ty>,
        width: f64,
        height: f64,
    ) -> HashMap<NodeId, NodePosition> {
        const ITERATIONS: usize = 50;
        const COOLING_FACTOR: f64 = 0.95;
        // Floor on inter-node distance so near-coincident nodes don't produce infinite forces.
        const MIN_DISTANCE: f64 = 0.01;

        let nodes: Vec<NodeId> = graph.nodes().map(|(id, _)| id).collect();
        if nodes.is_empty() {
            return HashMap::new();
        }

        let mut rng = rand::rng();
        let mut positions: Vec<NodePosition> = nodes
            .iter()
            .map(|_| NodePosition {
                x: Self::random_coordinate(&mut rng, width),
                y: Self::random_coordinate(&mut rng, height),
            })
            .collect();

        let k = (width * height / nodes.len() as f64).sqrt();
        let can_simulate = width > 0.0 && height > 0.0 && k.is_finite();
        if can_simulate {
            // Dense indices keep the O(n²) repulsion pass free of hashing; edges are
            // resolved to indices once instead of on every iteration.
            let index_of: HashMap<NodeId, usize> =
                nodes.iter().enumerate().map(|(i, &id)| (id, i)).collect();
            let mut edges: Vec<(usize, usize)> = Vec::new();
            for (src, tgt, _) in graph.edges() {
                if let (Some(&s), Some(&t)) = (index_of.get(&src), index_of.get(&tgt)) {
                    edges.push((s, t));
                }
            }

            let mut temperature = width.max(height) / 10.0;
            let mut displacements = vec![(0.0_f64, 0.0_f64); positions.len()];

            for _ in 0..ITERATIONS {
                // Repulsion: overwrites every slot, so no separate reset is needed.
                for (i, &pos_i) in positions.iter().enumerate() {
                    let mut force_x = 0.0;
                    let mut force_y = 0.0;
                    for (j, &pos_j) in positions.iter().enumerate() {
                        if i == j {
                            continue;
                        }
                        let mut delta_x = pos_i.x - pos_j.x;
                        let delta_y = pos_i.y - pos_j.y;
                        if delta_x == 0.0 && delta_y == 0.0 {
                            // Exactly coincident nodes (e.g. both clamped into a corner) have
                            // no repulsion direction and would stay merged forever; push the
                            // pair apart along x, in opposite directions by index order.
                            delta_x = if i < j { MIN_DISTANCE } else { -MIN_DISTANCE };
                        }
                        let distance =
                            (delta_x * delta_x + delta_y * delta_y).sqrt().max(MIN_DISTANCE);
                        let force = k * k / distance;
                        force_x += delta_x / distance * force;
                        force_y += delta_y / distance * force;
                    }
                    displacements[i] = (force_x, force_y);
                }

                // Attraction along edges: source moves toward target and vice versa.
                for &(s, t) in &edges {
                    let (from, to) = (positions[s], positions[t]);
                    let delta_x = to.x - from.x;
                    let delta_y = to.y - from.y;
                    let distance =
                        (delta_x * delta_x + delta_y * delta_y).sqrt().max(MIN_DISTANCE);
                    let force = distance * distance / k;
                    let pull_x = delta_x / distance * force;
                    let pull_y = delta_y / distance * force;
                    displacements[s].0 += pull_x;
                    displacements[s].1 += pull_y;
                    displacements[t].0 -= pull_x;
                    displacements[t].1 -= pull_y;
                }

                // Move each node by at most `temperature`, then clamp into the canvas.
                for (pos, &(dx, dy)) in positions.iter_mut().zip(&displacements) {
                    let magnitude = (dx * dx + dy * dy).sqrt();
                    if magnitude > 0.0 {
                        let step = magnitude.min(temperature);
                        pos.x = (pos.x + dx / magnitude * step).max(0.0).min(width);
                        pos.y = (pos.y + dy / magnitude * step).max(0.0).min(height);
                    }
                }

                temperature *= COOLING_FACTOR;
            }
        }

        nodes.into_iter().zip(positions).collect()
    }

    /// Places nodes evenly around a circle of radius `min(width, height) / 2.5`
    /// centred on the canvas, in the order `graph.nodes()` yields them.
    ///
    /// A single node has no ring to sit on and is placed at the canvas centre.
    fn circular_layout<A, W, Ty: GraphConstructor<A, W> + EdgeType>(
        graph: &BaseGraph<A, W, Ty>,
        width: f64,
        height: f64,
    ) -> HashMap<NodeId, NodePosition> {
        let nodes: Vec<NodeId> = graph.nodes().map(|(id, _)| id).collect();
        if nodes.is_empty() {
            return HashMap::new();
        }
        let center_x = width / 2.0;
        let center_y = height / 2.0;
        if let [only] = nodes.as_slice() {
            return HashMap::from([(*only, NodePosition { x: center_x, y: center_y })]);
        }

        let radius = width.min(height) / 2.5;
        let count = nodes.len() as f64;
        nodes
            .iter()
            .enumerate()
            .map(|(i, &node)| {
                let angle = 2.0 * std::f64::consts::PI * i as f64 / count;
                let position = NodePosition {
                    x: center_x + radius * angle.cos(),
                    y: center_y + radius * angle.sin(),
                };
                (node, position)
            })
            .collect()
    }

    /// Arranges nodes in horizontal rows by breadth-first depth.
    ///
    /// Traversal starts from every node whose `in_degree` is 0. If there is none, it
    /// starts from the first node yielded by `graph.nodes()`. Each node still unvisited
    /// afterwards seeds its own traversal, and the resulting rows are appended below the
    /// existing ones.
    ///
    /// Rows are spread evenly from `y = 0` to `y = height`. Nodes within a row are spread
    /// from `x = 0` to `x = width`. A lone row or lone node is centred on its axis.
    fn hierarchical_layout<A, W, Ty: GraphConstructor<A, W> + EdgeType>(
        graph: &BaseGraph<A, W, Ty>,
        width: f64,
        height: f64,
    ) -> HashMap<NodeId, NodePosition> {
        let nodes: Vec<NodeId> = graph.nodes().map(|(id, _)| id).collect();
        if nodes.is_empty() {
            return HashMap::new();
        }

        let mut visited = std::collections::HashSet::new();
        let roots: Vec<NodeId> = nodes
            .iter()
            .copied()
            .filter(|&n| graph.in_degree(n).unwrap_or(0) == 0)
            .collect();
        let seeds: &[NodeId] = if roots.is_empty() {
            &nodes[..1]
        } else {
            roots.as_slice()
        };
        let mut layers = Self::bfs_layers(graph, seeds, &mut visited);

        // Lay out every component the first traversal missed with its own BFS, rather than
        // dropping each stray node into a separate one-node row.
        for &node in &nodes {
            if !visited.contains(&node) {
                layers.extend(Self::bfs_layers(graph, &[node], &mut visited));
            }
        }

        let layer_count = layers.len();
        let mut positions = HashMap::with_capacity(nodes.len());
        for (layer_idx, layer) in layers.iter().enumerate() {
            let y = Self::evenly_spaced(layer_idx, layer_count, height);
            for (node_idx, &node) in layer.iter().enumerate() {
                let x = Self::evenly_spaced(node_idx, layer.len(), width);
                positions.insert(node, NodePosition { x, y });
            }
        }
        positions
    }

    /// Breadth-first traversal from `seeds`, grouping newly reached nodes by depth.
    ///
    /// Nodes already in `visited` are skipped. Every node this call places is inserted
    /// into `visited`, so successive calls never place a node twice. Empty depths are
    /// not emitted.
    fn bfs_layers<A, W, Ty: GraphConstructor<A, W> + EdgeType>(
        graph: &BaseGraph<A, W, Ty>,
        seeds: &[NodeId],
        visited: &mut std::collections::HashSet<NodeId>,
    ) -> Vec<Vec<NodeId>> {
        let mut layers = Vec::new();
        let mut queue: std::collections::VecDeque<NodeId> = seeds.iter().copied().collect();
        while !queue.is_empty() {
            let mut current_layer = Vec::new();
            // Drain only the nodes queued for this depth; neighbours pushed now go deeper.
            for _ in 0..queue.len() {
                if let Some(node) = queue.pop_front() {
                    // A node may be queued more than once; only its first dequeue places it.
                    if visited.insert(node) {
                        current_layer.push(node);
                        for neighbor in graph.neighbors(node) {
                            if !visited.contains(&neighbor) {
                                queue.push_back(neighbor);
                            }
                        }
                    }
                }
            }
            if !current_layer.is_empty() {
                layers.push(current_layer);
            }
        }
        layers
    }

    /// Coordinate of slot `index` when `count` slots are spread evenly over `[0, extent]`.
    ///
    /// A single slot is centred at `extent / 2`.
    fn evenly_spaced(index: usize, count: usize, extent: f64) -> f64 {
        if count > 1 {
            index as f64 * (extent / (count - 1) as f64)
        } else {
            extent / 2.0
        }
    }

    /// Places nodes row by row into a `cols × rows` grid of equal cells.
    ///
    /// `cols = ceil(sqrt(n))`. Each node sits at the centre of its cell, in the order
    /// `graph.nodes()` yields them.
    fn grid_layout<A, W, Ty: GraphConstructor<A, W> + EdgeType>(
        graph: &BaseGraph<A, W, Ty>,
        width: f64,
        height: f64,
    ) -> HashMap<NodeId, NodePosition> {
        let nodes: Vec<NodeId> = graph.nodes().map(|(id, _)| id).collect();
        if nodes.is_empty() {
            return HashMap::new();
        }
        let cols = (nodes.len() as f64).sqrt().ceil() as usize;
        let rows = (nodes.len() as f64 / cols as f64).ceil() as usize;
        let cell_width = width / cols as f64;
        let cell_height = height / rows as f64;
        nodes
            .iter()
            .enumerate()
            .map(|(i, &node)| {
                let (row, col) = (i / cols, i % cols);
                let position = NodePosition {
                    x: (col as f64 + 0.5) * cell_width,
                    y: (row as f64 + 0.5) * cell_height,
                };
                (node, position)
            })
            .collect()
    }

    /// Places every node at an independent uniformly random point in the canvas.
    fn random_layout<A, W, Ty: GraphConstructor<A, W> + EdgeType>(
        graph: &BaseGraph<A, W, Ty>,
        width: f64,
        height: f64,
    ) -> HashMap<NodeId, NodePosition> {
        let mut rng = rand::rng();
        graph
            .nodes()
            .map(|(node, _)| {
                let position = NodePosition {
                    x: Self::random_coordinate(&mut rng, width),
                    y: Self::random_coordinate(&mut rng, height),
                };
                (node, position)
            })
            .collect()
    }

    /// Draws a uniform coordinate in `[0, extent)`.
    ///
    /// Returns `0.0` when `extent` is not a positive finite number. For such values
    /// `random_range(0.0..extent)` would panic on an empty or invalid range.
    fn random_coordinate<R: rand::Rng>(rng: &mut R, extent: f64) -> f64 {
        if extent.is_finite() && extent > 0.0 {
            rng.random_range(0.0..extent)
        } else {
            0.0
        }
    }
}