use crate::graph::VisualGraph;
use std::collections::{HashMap, HashSet, VecDeque};
use serde::Deserialize;
/// Node positions keyed by node id, each stored as an `(x, y)` pair.
pub type LayoutMap = HashMap<usize, (f64, f64)>;

/// Selects which algorithm `compute_layout` uses to position a graph's nodes.
///
/// Derives serde `Deserialize` with no `#[serde]` attributes, so variants are matched by
/// their exact names. `Copy`/`PartialEq`/`Eq`/`Hash` let callers compare strategies, store
/// them in sets/maps, and pass them around by value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize)]
pub enum LayoutStrategy {
    /// Handled by `tree_top_down`.
    TreeTopDown,
    /// Handled by `circular_layout`.
    Circular,
    /// Handled by `linear_layout`.
    Linear,
    /// Handled by `layered_bfs`.
    LayeredBfs,
    /// Handled by `force_directed`.
    ForceDirected,
    /// Handled by `radial_layout`.
    Radial,
}
pub fn compute_layout(
graph: &VisualGraph,
strategy: &LayoutStrategy,
spacing: f32,
) -> LayoutMap {
match strategy {
LayoutStrategy::TreeTopDown => tree_top_down(graph, spacing),
LayoutStrategy::Circular => circular_layout(graph, spacing),
LayoutStrategy::Linear => linear_layout(graph, spacing),
LayoutStrategy::LayeredBfs => layered_bfs(graph, spacing),
LayoutStrategy::ForceDirected => force_directed(graph, spacing),
LayoutStrategy::Radial => radial_layout(graph, spacing),
}
}
/// Lays nodes out top-down by depth along `from -> to` edges.
///
/// Roots are the nodes of `graph.nodes` that never appear as an edge's `to`. Each root is
/// walked depth-first, in `graph.nodes` order, and a node's level is the depth at which the
/// walk first reaches it. When roots exist, every node that no root reaches (e.g. a member
/// of a cycle) is then walked from level 0 as well, so no node is left without a position.
/// When there are no roots at all, every node is placed on level 0.
///
/// Level `k` is drawn at `y = 50.0 + 100.0 * k`. Within a level, nodes are `80.0` apart in
/// the order the walk first reached them, which makes the result deterministic. Rows of up to
/// 7 nodes are centred on `x = 360.0`; wider rows start at `x = 100.0`.
///
/// `_spacing` is currently ignored.
fn tree_top_down(graph: &VisualGraph, _spacing: f32) -> LayoutMap {
    const Y_STEP: f64 = 100.0;
    const X_STEP: f64 = 80.0;

    /// Iterative pre-order DFS from `start` (at level 0), recording `(id, level)` for every
    /// node not yet visited. Iterative so very deep chains cannot overflow the call stack.
    fn walk(
        start: usize,
        children: &HashMap<usize, Vec<usize>>,
        visited: &mut HashSet<usize>,
        placed: &mut Vec<(usize, usize)>,
    ) {
        let mut stack = vec![(start, 0usize)];
        while let Some((id, level)) = stack.pop() {
            if !visited.insert(id) {
                continue;
            }
            placed.push((id, level));
            if let Some(kids) = children.get(&id) {
                // Reverse so children are expanded in edge order, exactly like a recursive
                // pre-order walk (first visit fixes the level).
                stack.extend(kids.iter().rev().map(|&kid| (kid, level + 1)));
            }
        }
    }

    let mut children: HashMap<usize, Vec<usize>> = HashMap::new();
    let mut has_parent: HashSet<usize> = HashSet::new();
    for edge in &graph.edges {
        has_parent.insert(edge.to);
        children.entry(edge.from).or_default().push(edge.to);
    }

    let mut visited: HashSet<usize> = HashSet::new();
    let mut placed: Vec<(usize, usize)> = Vec::new();
    let has_root = graph.nodes.iter().any(|n| !has_parent.contains(&n.id));
    if has_root {
        for node in graph.nodes.iter().filter(|n| !has_parent.contains(&n.id)) {
            walk(node.id, &children, &mut visited, &mut placed);
        }
        // Nodes unreachable from every root (cycle members and their descendants) would
        // otherwise be missing from the layout; already-visited nodes are no-ops here.
        for node in &graph.nodes {
            walk(node.id, &children, &mut visited, &mut placed);
        }
    } else {
        for node in &graph.nodes {
            if visited.insert(node.id) {
                placed.push((node.id, 0));
            }
        }
    }

    let mut rows: HashMap<usize, Vec<usize>> = HashMap::new();
    for &(id, level) in &placed {
        rows.entry(level).or_default().push(id);
    }

    let mut layout = HashMap::with_capacity(placed.len());
    for (level, ids) in rows {
        let y = 50.0 + level as f64 * Y_STEP;
        // Shift short rows right so they sit centred; wide rows are clamped to start at 100.
        let offset = (300.0 - X_STEP * ids.len() as f64 / 2.0).max(0.0);
        for (i, id) in ids.into_iter().enumerate() {
            layout.insert(id, (100.0 + i as f64 * X_STEP + offset, y));
        }
    }
    layout
}
fn circular_layout(graph: &VisualGraph, _spacing: f32) -> LayoutMap {
let radius = 200.0;
let center_x = 300.0;
let center_y = 300.0;
let n = graph.nodes.len();
let mut layout = HashMap::new();
for (i, node) in graph.nodes.iter().enumerate() {
let angle = (i as f64) * (2.0 * std::f64::consts::PI) / (n as f64);
let x = center_x + radius * angle.cos();
let y = center_y + radius * angle.sin();
layout.insert(node.id, (x, y));
}
layout
}
/// Lays nodes out in rows by breadth-first distance along `from -> to` edges.
///
/// The walk starts at the first node of `graph.nodes`. Any node it does not reach starts a
/// fresh walk at level 0, taken in `graph.nodes` order, so disconnected components are laid
/// out too instead of being dropped.
///
/// Level `k` is drawn at `y = 50.0 + 100.0 * k`. Within a level, nodes are `80.0` apart in
/// discovery order, which makes the result deterministic. Rows of up to 7 nodes are centred
/// on `x = 360.0`; wider rows start at `x = 100.0`.
///
/// `_spacing` is currently ignored.
fn layered_bfs(graph: &VisualGraph, _spacing: f32) -> LayoutMap {
    const Y_STEP: f64 = 100.0;
    const X_STEP: f64 = 80.0;

    // Build the adjacency once so the walk is O(V + E) rather than rescanning every edge per
    // dequeued node. Successors keep edge order.
    let mut successors: HashMap<usize, Vec<usize>> = HashMap::new();
    for edge in &graph.edges {
        successors.entry(edge.from).or_default().push(edge.to);
    }

    let mut visited: HashSet<usize> = HashSet::new();
    let mut rows: HashMap<usize, Vec<usize>> = HashMap::new();
    let mut queue: VecDeque<(usize, usize)> = VecDeque::new();
    for node in &graph.nodes {
        // Each still-unvisited node seeds a new BFS at level 0.
        if !visited.insert(node.id) {
            continue;
        }
        queue.push_back((node.id, 0));
        while let Some((id, level)) = queue.pop_front() {
            rows.entry(level).or_default().push(id);
            if let Some(next_ids) = successors.get(&id) {
                for &next in next_ids {
                    if visited.insert(next) {
                        queue.push_back((next, level + 1));
                    }
                }
            }
        }
    }

    let mut layout = HashMap::with_capacity(visited.len());
    for (level, ids) in rows {
        let y = 50.0 + level as f64 * Y_STEP;
        // Shift short rows right so they sit centred; wide rows are clamped to start at 100.
        let offset = (300.0 - X_STEP * ids.len() as f64 / 2.0).max(0.0);
        for (i, id) in ids.into_iter().enumerate() {
            layout.insert(id, (100.0 + i as f64 * X_STEP + offset, y));
        }
    }
    layout
}
fn linear_layout(graph: &VisualGraph, _spacing: f32) -> LayoutMap {
circular_layout(graph, _spacing)
}
/// Placeholder for the force-directed strategy.
///
/// No force simulation is performed yet. This returns exactly what `circular_layout`
/// produces for the same `graph` and `spacing`.
fn force_directed(graph: &VisualGraph, spacing: f32) -> LayoutMap {
    // TODO: replace with an actual force-directed simulation.
    circular_layout(graph, spacing)
}
/// Placeholder for the radial strategy.
///
/// No ring-by-depth arrangement is computed yet. This returns exactly what
/// `circular_layout` produces for the same `graph` and `spacing`.
fn radial_layout(graph: &VisualGraph, spacing: f32) -> LayoutMap {
    // TODO: replace with a dedicated radial (concentric-ring) layout.
    circular_layout(graph, spacing)
}