Skip to main content

lore_engine/engine/
force_layout.rs

1//! Force-directed graph layout using a simple spring-embedder algorithm.
2//!
3//! Takes nodes + edges from `WikiGraph` and computes (x, y) positions
4//! for each node. Deterministic (no randomness).
5
6use serde::Serialize;
7use std::collections::HashMap;
8use std::f64::consts::PI;
9
10use super::graph::{GraphEdge, GraphNode};
11
12// ─── Layout parameters ──────────────────────────────────────────────────
13
// Force strengths.

/// Scale of the inverse-square repulsion that every pair of nodes exerts on each other.
const REPULSION: f64 = 5_000.0;
/// Spring stiffness along edges: the pull (or push) grows linearly with the
/// difference between an edge's current length and `IDEAL_LENGTH`.
const ATTRACTION: f64 = 0.01;
/// Rest length of an edge spring; also scales the radius of the initial circle.
const IDEAL_LENGTH: f64 = 100.0;

// Stability controls.

/// Multiplier applied to each node's velocity (after forces are added) on every step.
const DAMPING: f64 = 0.85;
/// Largest distance a node may move in a single step.
const MAX_DISPLACEMENT: f64 = 50.0;
/// Lower bound on the distance used in force calculations, so nearly
/// overlapping nodes do not produce unbounded forces.
const MIN_DISTANCE: f64 = 1.0;

/// Number of simulation steps performed by `layout`.
const ITERATIONS: usize = 200;
21
22// ─── Output types ────────────────────────────────────────────────────────
23
/// A graph node with its computed `(x, y)` position.
///
/// Field order is part of the serialized form (serde derive), so keep it stable.
#[derive(Debug, Serialize, Clone)]
pub struct LayoutNode {
    /// Slug copied from the input node.
    pub slug: String,
    /// Title copied from the input node.
    pub title: String,
    /// Placeholder flag copied from the input node.
    pub is_placeholder: bool,
    /// Horizontal position; `layout` centres results so the nodes' mean `x` is 0.
    pub x: f64,
    /// Vertical position; `layout` centres results so the nodes' mean `y` is 0.
    pub y: f64,
}
33
/// Complete layout result: positioned nodes plus the edges that were passed in.
///
/// Field order is part of the serialized form (serde derive), so keep it stable.
#[derive(Debug, Serialize, Clone)]
pub struct GraphLayout {
    /// Positioned nodes, in the same order as the input nodes.
    pub nodes: Vec<LayoutNode>,
    /// The input edges, returned unmodified — including any whose endpoint
    /// slugs matched no node (such edges do not influence positions).
    pub edges: Vec<GraphEdge>,
}
40
41// ─── Internal simulation state ───────────────────────────────────────────
42
/// Per-node simulation state used by `layout`.
///
/// Velocity is carried over between iterations (it is damped, never reset),
/// so nodes keep momentum from one step to the next.
struct SimNode {
    /// Current horizontal position.
    x: f64,
    /// Current vertical position.
    y: f64,
    /// Current horizontal velocity; added to `x` once per iteration.
    vx: f64,
    /// Current vertical velocity; added to `y` once per iteration.
    vy: f64,
}
49
50// ─── Public API ──────────────────────────────────────────────────────────
51
52/// Compute a force-directed layout for the given graph data.
53/// Returns positioned nodes and the original edges.
54pub fn layout(nodes: Vec<GraphNode>, edges: Vec<GraphEdge>) -> GraphLayout {
55    let n = nodes.len();
56
57    if n == 0 {
58        return GraphLayout {
59            nodes: vec![],
60            edges,
61        };
62    }
63
64    if n == 1 {
65        return GraphLayout {
66            nodes: vec![LayoutNode {
67                slug: nodes[0].slug.clone(),
68                title: nodes[0].title.clone(),
69                is_placeholder: nodes[0].is_placeholder,
70                x: 0.0,
71                y: 0.0,
72            }],
73            edges,
74        };
75    }
76
77    // Build slug → index lookup
78    let slug_to_idx: HashMap<&str, usize> = nodes
79        .iter()
80        .enumerate()
81        .map(|(i, n)| (n.slug.as_str(), i))
82        .collect();
83
84    // Resolve edges to index pairs
85    let edge_pairs: Vec<(usize, usize)> = edges
86        .iter()
87        .filter_map(|e| {
88            let s = slug_to_idx.get(e.source.as_str())?;
89            let t = slug_to_idx.get(e.target.as_str())?;
90            Some((*s, *t))
91        })
92        .collect();
93
94    // Initialize positions in a circle
95    let radius = (n as f64).sqrt() * IDEAL_LENGTH * 0.5;
96    let mut sim: Vec<SimNode> = (0..n)
97        .map(|i| {
98            let angle = 2.0 * PI * (i as f64) / (n as f64);
99            SimNode {
100                x: radius * angle.cos(),
101                y: radius * angle.sin(),
102                vx: 0.0,
103                vy: 0.0,
104            }
105        })
106        .collect();
107
108    // Run simulation
109    for _ in 0..ITERATIONS {
110        // Repulsion between all pairs
111        for i in 0..n {
112            for j in (i + 1)..n {
113                let dx = sim[i].x - sim[j].x;
114                let dy = sim[i].y - sim[j].y;
115                let dist = dx.hypot(dy).max(MIN_DISTANCE);
116                let force = REPULSION / (dist * dist);
117                let fx = (dx / dist) * force;
118                let fy = (dy / dist) * force;
119                sim[i].vx += fx;
120                sim[i].vy += fy;
121                sim[j].vx -= fx;
122                sim[j].vy -= fy;
123            }
124        }
125
126        // Attraction along edges
127        for &(s, t) in &edge_pairs {
128            let dx = sim[t].x - sim[s].x;
129            let dy = sim[t].y - sim[s].y;
130            let dist = dx.hypot(dy).max(MIN_DISTANCE);
131            let force = ATTRACTION * (dist - IDEAL_LENGTH);
132            let fx = (dx / dist) * force;
133            let fy = (dy / dist) * force;
134            sim[s].vx += fx;
135            sim[s].vy += fy;
136            sim[t].vx -= fx;
137            sim[t].vy -= fy;
138        }
139
140        // Apply velocities with damping and displacement cap
141        for node in &mut sim {
142            node.vx *= DAMPING;
143            node.vy *= DAMPING;
144
145            let disp = node.vx.hypot(node.vy);
146            if disp > MAX_DISPLACEMENT {
147                let scale = MAX_DISPLACEMENT / disp;
148                node.vx *= scale;
149                node.vy *= scale;
150            }
151
152            node.x += node.vx;
153            node.y += node.vy;
154        }
155    }
156
157    // Center the layout around (0, 0)
158    let cx: f64 = sim.iter().map(|n| n.x).sum::<f64>() / n as f64;
159    let cy: f64 = sim.iter().map(|n| n.y).sum::<f64>() / n as f64;
160    for node in &mut sim {
161        node.x -= cx;
162        node.y -= cy;
163    }
164
165    // Build output
166    let layout_nodes: Vec<LayoutNode> = nodes
167        .iter()
168        .zip(sim.iter())
169        .map(|(gn, sn)| LayoutNode {
170            slug: gn.slug.clone(),
171            title: gn.title.clone(),
172            is_placeholder: gn.is_placeholder,
173            x: sn.x,
174            y: sn.y,
175        })
176        .collect();
177
178    GraphLayout {
179        nodes: layout_nodes,
180        edges,
181    }
182}
183
184// ─── Tests ───────────────────────────────────────────────────────────────
185
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a non-placeholder node whose title is the upper-cased slug.
    fn node(slug: &str) -> GraphNode {
        GraphNode {
            slug: slug.into(),
            title: slug.to_uppercase(),
            is_placeholder: false,
        }
    }

    /// Build an edge between two slugs.
    fn edge(source: &str, target: &str) -> GraphEdge {
        GraphEdge {
            source: source.into(),
            target: target.into(),
        }
    }

    /// Extract `(x, y)` pairs so positions can be compared as a whole.
    fn positions(result: &GraphLayout) -> Vec<(f64, f64)> {
        result.nodes.iter().map(|n| (n.x, n.y)).collect()
    }

    #[test]
    fn test_empty_graph() {
        let result = layout(vec![], vec![]);
        assert!(result.nodes.is_empty());
        assert!(result.edges.is_empty());
    }

    #[test]
    fn test_single_node() {
        let result = layout(vec![node("a")], vec![]);
        assert_eq!(result.nodes.len(), 1);
        assert_eq!(result.nodes[0].slug, "a");
        assert_eq!(result.nodes[0].title, "A");
        assert_eq!(result.nodes[0].x, 0.0);
        assert_eq!(result.nodes[0].y, 0.0);
    }

    #[test]
    fn test_two_connected_nodes_separate() {
        let result = layout(vec![node("a"), node("b")], vec![edge("a", "b")]);
        assert_eq!(result.nodes.len(), 2);
        // Nodes should be separated (not on top of each other)
        let dx = result.nodes[0].x - result.nodes[1].x;
        let dy = result.nodes[0].y - result.nodes[1].y;
        let dist = dx.hypot(dy);
        assert!(dist > 10.0, "Nodes should be separated, got dist={dist}");
    }

    #[test]
    fn test_layout_is_centered() {
        let result = layout(
            vec![node("a"), node("b"), node("c")],
            vec![edge("a", "b"), edge("b", "c")],
        );
        let cx: f64 = result.nodes.iter().map(|n| n.x).sum::<f64>() / 3.0;
        let cy: f64 = result.nodes.iter().map(|n| n.y).sum::<f64>() / 3.0;
        assert!(cx.abs() < 0.01, "Center X should be ~0, got {cx}");
        assert!(cy.abs() < 0.01, "Center Y should be ~0, got {cy}");
    }

    #[test]
    fn test_layout_is_deterministic() {
        let build = || {
            (
                vec![node("a"), node("b"), node("c"), node("d")],
                vec![edge("a", "b"), edge("b", "c"), edge("c", "d"), edge("d", "a")],
            )
        };
        let (n1, e1) = build();
        let (n2, e2) = build();
        assert_eq!(positions(&layout(n1, e1)), positions(&layout(n2, e2)));
    }

    #[test]
    fn test_preserves_node_order_and_metadata() {
        let mut nodes = vec![node("b"), node("a"), node("c")];
        nodes[1].is_placeholder = true;
        let result = layout(nodes, vec![edge("b", "a")]);
        let slugs: Vec<&str> = result.nodes.iter().map(|n| n.slug.as_str()).collect();
        assert_eq!(slugs, ["b", "a", "c"]);
        assert_eq!(result.nodes[1].title, "A");
        assert!(result.nodes[1].is_placeholder);
        assert!(!result.nodes[0].is_placeholder);
    }

    #[test]
    fn test_dangling_edges_are_returned_but_ignored() {
        let with_dangling = layout(
            vec![node("a"), node("b")],
            vec![edge("a", "b"), edge("a", "missing")],
        );
        let without = layout(vec![node("a"), node("b")], vec![edge("a", "b")]);
        assert_eq!(with_dangling.edges.len(), 2);
        assert_eq!(with_dangling.edges[1].target, "missing");
        assert_eq!(positions(&with_dangling), positions(&without));
    }
}