use rand::Rng;
use ratatui::style::Color;
/// Semantic role of a graph element, translated into a display color by
/// `graph_relation_color`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GraphRelationSemantic {
Neutral,
// `Focus` and `CascadeSource` are rendered with the same color.
Focus,
// Rendered as plain `Color::White` rather than an RGB triple.
Hover,
Inbound,
Outbound,
Bidirectional,
Related,
SearchMatch,
CascadeSource,
CascadeDimmed,
}
/// Semantic role of a cluster-map element, translated into a display color by
/// `cluster_map_color`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ClusterMapSemantic {
// Shares its color with `GraphRelationSemantic::Focus`.
Hovered,
Central,
ThirdParty,
Entrypoint,
Support,
Neutral,
}
/// Semantic role of an overview edge, translated into a display color by
/// `overview_edge_color`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OverviewEdgeSemantic {
PrimaryBridge,
ExternalBridge,
ExternalSink,
}
/// Simulation state of a single node in a `GraphLayout`.
#[derive(Debug, Clone)]
pub struct NodePosition {
// Current coordinates.
pub x: f64,
pub y: f64,
// Coordinates before the most recent integration step; `GraphLayout::step`
// uses `x - prev_x` / `y - prev_y` as the node's implied velocity.
pub prev_x: f64,
pub prev_y: f64,
// When true, `GraphLayout::step` does not move the node (the final
// boundary clamp is still applied to it).
pub pinned: bool,
}
/// Square quadtree cell used to approximate node-to-node repulsion.
///
/// Every cell tracks how many points it contains (`mass`) and their center of
/// mass. A cell is either empty, a leaf holding one point (`node_idx` set), an
/// internal cell with four children, or - at `MAX_QUADTREE_DEPTH` - a leaf that
/// aggregates several points without subdividing further.
struct Quadtree {
    /// Number of points accumulated in this cell (each insert adds 1.0).
    mass: f64,
    /// Center of mass of all points accumulated in this cell.
    com_x: f64,
    com_y: f64,
    /// Minimum-coordinate corner of the cell.
    x: f64,
    y: f64,
    /// Side length of the cell.
    size: f64,
    /// The four sub-cells, present once the cell has been subdivided.
    children: Option<Box<[Quadtree; 4]>>,
    /// Index of the point stored in a leaf cell.
    node_idx: Option<usize>,
    /// Distance from the root cell (root is 0).
    depth: usize,
}

/// Cells at this depth stop subdividing and aggregate points instead, which
/// bounds recursion when points coincide.
const MAX_QUADTREE_DEPTH: usize = 20;

impl Quadtree {
    /// Creates an empty cell with corner `(x, y)`, side `size` and the given `depth`.
    fn new(x: f64, y: f64, size: f64, depth: usize) -> Self {
        Self {
            mass: 0.0,
            com_x: 0.0,
            com_y: 0.0,
            x,
            y,
            size,
            children: None,
            node_idx: None,
            depth,
        }
    }

    /// Inserts point `idx` located at `(px, py)` into this cell's subtree.
    fn insert(&mut self, idx: usize, px: f64, py: f64) {
        let occupied_leaf = self.children.is_none() && self.mass > 0.0;
        if occupied_leaf {
            if self.depth >= MAX_QUADTREE_DEPTH {
                // Too deep to split again: fold the point into the aggregate.
                self.absorb(px, py);
                return;
            }
            self.subdivide();
        }

        let slot = self.get_quadrant(px, py);
        match self.children.as_mut() {
            Some(kids) => kids[slot].insert(idx, px, py),
            None => {
                // Empty leaf: the point becomes the sole resident.
                self.node_idx = Some(idx);
                self.com_x = px;
                self.com_y = py;
            }
        }
        self.absorb(px, py);
    }

    /// Adds one unit of mass at `(px, py)` to the running center of mass.
    fn absorb(&mut self, px: f64, py: f64) {
        let total = self.mass + 1.0;
        self.com_x = (self.com_x * self.mass + px) / total;
        self.com_y = (self.com_y * self.mass + py) / total;
        self.mass = total;
    }

    /// Splits a leaf into four children and hands its resident point down.
    fn subdivide(&mut self) {
        let half = self.size / 2.0;
        let child_depth = self.depth + 1;
        let (left, top) = (self.x, self.y);
        let mut quads = Box::new([
            Quadtree::new(left, top, half, child_depth),
            Quadtree::new(left + half, top, half, child_depth),
            Quadtree::new(left, top + half, half, child_depth),
            Quadtree::new(left + half, top + half, half, child_depth),
        ]);
        if let Some(resident) = self.node_idx.take() {
            // A single-point leaf's center of mass is that point's position.
            let (rx, ry) = (self.com_x, self.com_y);
            let slot = self.get_quadrant(rx, ry);
            quads[slot].insert(resident, rx, ry);
        }
        self.children = Some(quads);
    }

    /// Returns the child slot for `(px, py)`: bit 0 is set when `px` is at or
    /// beyond the cell's x midpoint, bit 1 when `py` is at or beyond its y
    /// midpoint.
    fn get_quadrant(&self, px: f64, py: f64) -> usize {
        let half = self.size / 2.0;
        let east = px >= self.x + half;
        let south = py >= self.y + half;
        usize::from(east) + 2 * usize::from(south)
    }

    /// Accumulates into `force` the repulsion that this subtree exerts on
    /// point `idx` located at `pos`.
    ///
    /// A leaf, or an internal cell whose `size / distance` ratio is below
    /// `theta`, is treated as one body at its center of mass; any other cell
    /// is descended into. The leaf that holds `idx` itself contributes nothing.
    fn compute_repulsion(
        &self,
        idx: usize,
        pos: (f64, f64),
        theta: f64,
        repulsion_const: f64,
        force: &mut (f64, f64),
    ) {
        if self.mass == 0.0 || self.node_idx == Some(idx) {
            return;
        }
        let (dx, dy) = (self.com_x - pos.0, self.com_y - pos.1);
        let dist_sq = dx * dx + dy * dy;
        // Floor the distance so coincident points never divide by zero.
        let dist = dist_sq.sqrt().max(0.5);
        let far_enough = self.size / dist < theta;

        match &self.children {
            Some(kids) if !far_enough => {
                for kid in kids.iter() {
                    kid.compute_repulsion(idx, pos, theta, repulsion_const, force);
                }
            }
            _ => {
                let f = (repulsion_const * self.mass) / dist_sq.max(1.0);
                force.0 -= f * (dx / dist);
                force.1 -= f * (dy / dist);
            }
        }
    }
}
/// State of a force-directed graph layout simulated on a `width` x `height`
/// canvas.
pub struct GraphLayout {
// Per-node simulation state; the constructor creates one entry per label.
pub positions: Vec<NodePosition>,
// Edges as `(from, to)` index pairs into `positions`; `step` ignores pairs
// with an out-of-range index.
pub edges: Vec<(usize, usize)>,
// Stored alongside `edges`; not read by `step`.
pub edge_weights: Vec<u32>,
pub labels: Vec<String>,
// Base repulsion constant; `step` scales it down for graphs with more than
// 100 nodes.
pub repulsion: f64,
// Spring coefficient applied to an edge's deviation from `ideal_length`.
pub attraction: f64,
// Multiplier applied to each node's implied velocity on every step.
pub damping: f64,
// Rest length of an edge, derived from node count and canvas size.
pub ideal_length: f64,
pub width: f64,
pub height: f64,
// Scales the applied forces; multiplied by 0.997 on every step with a
// floor of 0.001.
pub temperature: f64,
}
impl GraphLayout {
/// Builds a layout for `labels` connected by `edges` on a `width` x `height`
/// canvas.
///
/// Nodes start on a slightly jittered ring around the canvas center whose
/// radius is 40% of the shorter canvas side. Each node's previous position
/// is offset by a small random amount, so it starts with a small random
/// implied velocity. Simulation parameters get their default values and the
/// temperature starts at 1.0.
pub fn new(
    labels: Vec<String>,
    edges: Vec<(usize, usize)>,
    edge_weights: Vec<u32>,
    width: f64,
    height: f64,
) -> Self {
    let mut rng = rand::rng();
    let node_count = labels.len();
    let ring_radius = width.min(height) * 0.40;
    let (center_x, center_y) = (width / 2.0, height / 2.0);

    let mut positions = Vec::with_capacity(node_count);
    for i in 0..node_count {
        // Spread nodes evenly around the ring, then perturb angle and radius.
        let angle = (i as f64 / node_count.max(1) as f64) * std::f64::consts::TAU
            + rng.random_range(-0.15..0.15);
        let r = ring_radius + rng.random_range(-5.0..5.0);
        let x = center_x + angle.cos() * r;
        let y = center_y + angle.sin() * r;
        let prev_x = x + rng.random_range(-1.5..1.5);
        let prev_y = y + rng.random_range(-1.5..1.5);
        positions.push(NodePosition {
            x,
            y,
            prev_x,
            prev_y,
            pinned: false,
        });
    }

    Self {
        ideal_length: compute_ideal_length(node_count, width, height),
        positions,
        edges,
        edge_weights,
        labels,
        repulsion: 1500.0,
        attraction: 0.045,
        damping: 0.7,
        width,
        height,
        temperature: 1.0,
    }
}
/// Advances the simulation by one iteration.
///
/// Forces are accumulated per node from four sources - quadtree-approximated
/// repulsion, edge springs, a pull toward the canvas center with a soft push
/// away from the borders, and (while the temperature is above 0.5) random
/// jitter - then every unpinned node is moved, all nodes are clamped inside
/// the canvas margin, and the temperature is cooled. Does nothing for an
/// empty layout.
pub fn step(&mut self) {
let n = self.positions.len();
if n == 0 {
return;
}
// Per-node force accumulators.
let mut fx = vec![0.0f64; n];
let mut fy = vec![0.0f64; n];
// Root cell is anchored at the origin and covers the longer canvas side.
let q_size = self.width.max(self.height).max(1.0);
let mut qt = Quadtree::new(0.0, 0.0, q_size, 0);
for (i, pos) in self.positions.iter().enumerate() {
qt.insert(i, pos.x, pos.y);
}
// Cells whose size/distance ratio is below theta are treated as one body.
let theta = 0.7;
// Weaken repulsion for graphs with more than 100 nodes, but never below 400.
let dynamic_repulsion = if n > 100 {
(self.repulsion * (100.0 / n as f64).sqrt()).max(400.0)
} else {
self.repulsion
};
let repulsion_const = dynamic_repulsion;
for i in 0..n {
let mut f = (0.0, 0.0);
qt.compute_repulsion(
i,
(self.positions[i].x, self.positions[i].y),
theta,
repulsion_const,
&mut f,
);
fx[i] += f.0;
fy[i] += f.1;
}
// Edge springs: endpoints attract when farther apart than `ideal_length`
// and repel when closer. Edges with an out-of-range index are skipped.
for &(from, to) in &self.edges {
if from >= n || to >= n {
continue;
}
let dx = self.positions[to].x - self.positions[from].x;
let dy = self.positions[to].y - self.positions[from].y;
// Distance is floored at 0.5 so the unit vector below is always finite.
let dist = (dx * dx + dy * dy).sqrt().max(0.5);
let displacement = dist - self.ideal_length;
let force = self.attraction * displacement;
let ux = dx / dist;
let uy = dy / dist;
fx[from] += force * ux;
fy[from] += force * uy;
fx[to] -= force * ux;
fy[to] -= force * uy;
}
let cx = self.width / 2.0;
let cy = self.height / 2.0;
// Pull toward the canvas center; stronger for graphs with more than 200 nodes.
let gravity = if n > 200 { 0.08 } else { 0.05 };
for i in 0..n {
let dx = cx - self.positions[i].x;
let dy = cy - self.positions[i].y;
fx[i] += dx * gravity;
fy[i] += dy * gravity;
// Soft border: nodes within 50 units of an edge are pushed back inward
// in proportion to how far they are inside that band.
let margin_soft = 50.0;
let push_strength = 0.25; if self.positions[i].x < margin_soft {
fx[i] += (margin_soft - self.positions[i].x) * push_strength;
} else if self.positions[i].x > self.width - margin_soft {
fx[i] -= (self.positions[i].x - (self.width - margin_soft)) * push_strength;
}
if self.positions[i].y < margin_soft {
fy[i] += (margin_soft - self.positions[i].y) * push_strength;
} else if self.positions[i].y > self.height - margin_soft {
fy[i] -= (self.positions[i].y - (self.height - margin_soft)) * push_strength;
}
}
let mut rng = rand::rng();
// Random jitter is only applied while the temperature is above 0.5.
let jitter = if self.temperature > 0.5 {
0.5 * self.temperature
} else {
0.0
};
if jitter > 0.001 {
for i in 0..n {
fx[i] += rng.random_range(-jitter..jitter);
fy[i] += rng.random_range(-jitter..jitter);
}
}
// Per-axis cap applied separately to the velocity and force contributions.
let max_disp = 5.0;
for i in 0..n {
if self.positions[i].pinned {
continue;
}
// Implied velocity: damped difference between current and previous position.
let vx = (self.positions[i].x - self.positions[i].prev_x) * self.damping;
let vy = (self.positions[i].y - self.positions[i].prev_y) * self.damping;
let mut fxi = fx[i] * self.temperature;
let mut fyi = fy[i] * self.temperature;
let vx = vx.clamp(-max_disp, max_disp);
let vy = vy.clamp(-max_disp, max_disp);
fxi = fxi.clamp(-max_disp, max_disp);
fyi = fyi.clamp(-max_disp, max_disp);
let new_x = self.positions[i].x + vx + fxi;
let new_y = self.positions[i].y + vy + fyi;
self.positions[i].prev_x = self.positions[i].x;
self.positions[i].prev_y = self.positions[i].y;
self.positions[i].x = new_x;
self.positions[i].y = new_y;
}
// Hard border: keep every node (pinned ones included) at least `margin`
// inside the canvas. The previous position is reset too, which zeroes the
// implied velocity along the clamped axis.
let margin = 8.0;
for pos in &mut self.positions {
if pos.x < margin {
pos.x = margin;
pos.prev_x = margin;
} else if pos.x > self.width - margin {
pos.x = self.width - margin;
pos.prev_x = self.width - margin;
}
if pos.y < margin {
pos.y = margin;
pos.prev_y = margin;
} else if pos.y > self.height - margin {
pos.y = self.height - margin;
pos.prev_y = self.height - margin;
}
}
// Cool down geometrically, with a floor of 0.001.
self.temperature = (self.temperature * 0.997).max(0.001);
}
/// Runs `step` `count` times in a row.
pub fn multi_step(&mut self, count: usize) {
    (0..count).for_each(|_| self.step());
}
/// Sets the temperature to 1.5 (above the initial 1.0), so subsequent steps
/// apply forces at greater strength.
pub fn reheat(&mut self) {
    const REHEAT_TEMPERATURE: f64 = 1.5;
    self.temperature = REHEAT_TEMPERATURE;
}
/// Discards all node positions and places the nodes afresh on a jittered ring
/// around the canvas center (radius 35% of the shorter canvas side).
///
/// Every node becomes unpinned, the temperature is reset to 1.0 and the ideal
/// edge length is recomputed.
pub fn reinitialize_positions(&mut self) {
    let mut rng = rand::rng();
    let node_count = self.labels.len();
    let ring_radius = self.width.min(self.height) * 0.35;
    let (center_x, center_y) = (self.width / 2.0, self.height / 2.0);

    let mut fresh = Vec::with_capacity(node_count);
    for i in 0..node_count {
        // Spread nodes evenly around the ring, then perturb angle and radius.
        let angle = (i as f64 / node_count.max(1) as f64) * std::f64::consts::TAU
            + rng.random_range(-0.15..0.15);
        let r = ring_radius + rng.random_range(-5.0..5.0);
        let x = center_x + angle.cos() * r;
        let y = center_y + angle.sin() * r;
        let prev_x = x + rng.random_range(-1.5..1.5);
        let prev_y = y + rng.random_range(-1.5..1.5);
        fresh.push(NodePosition {
            x,
            y,
            prev_x,
            prev_y,
            pinned: false,
        });
    }

    self.positions = fresh;
    self.temperature = 1.0;
    self.ideal_length = compute_ideal_length(node_count, self.width, self.height);
}
/// Translates the whole layout so that the average node position coincides
/// with the canvas center, keeping every node at least 8 units inside the
/// canvas.
///
/// Current and previous positions are shifted by the same offset, so implied
/// velocities are preserved except where a coordinate gets clamped. Does
/// nothing for an empty layout.
pub fn center_layout(&mut self) {
    let n = self.positions.len();
    if n == 0 {
        return;
    }
    let avg_x: f64 = self.positions.iter().map(|p| p.x).sum::<f64>() / n as f64;
    let avg_y: f64 = self.positions.iter().map(|p| p.y).sum::<f64>() / n as f64;
    let dx = self.width / 2.0 - avg_x;
    let dy = self.height / 2.0 - avg_y;
    let margin = 8.0;
    // `f64::clamp` panics when min > max, which used to happen on canvases
    // narrower or shorter than `2 * margin`. Collapse the range onto `margin`
    // in that case, matching the preference of the boundary pass in `step`.
    let max_x = (self.width - margin).max(margin);
    let max_y = (self.height - margin).max(margin);
    for pos in &mut self.positions {
        pos.x = (pos.x + dx).clamp(margin, max_x);
        pos.y = (pos.y + dy).clamp(margin, max_y);
        pos.prev_x = (pos.prev_x + dx).clamp(margin, max_x);
        pos.prev_y = (pos.prev_y + dy).clamp(margin, max_y);
    }
}
/// Adapts the layout to a new canvas size.
///
/// Changes smaller than one unit in both dimensions are ignored. Otherwise
/// current and previous positions are scaled proportionally (skipped when the
/// old canvas had a non-positive dimension, to avoid dividing by zero), the
/// new size is stored and the ideal edge length is recomputed.
pub fn resize(&mut self, width: f64, height: f64) {
    // A second, narrower `< 0.5` check used to follow this one; it could never
    // fire because this check already covers that case, so it was removed.
    let unchanged = (self.width - width).abs() < 1.0 && (self.height - height).abs() < 1.0;
    if unchanged {
        return;
    }
    if self.width > 0.0 && self.height > 0.0 {
        let sx = width / self.width;
        let sy = height / self.height;
        for pos in &mut self.positions {
            pos.x *= sx;
            pos.y *= sy;
            pos.prev_x *= sx;
            pos.prev_y *= sy;
        }
    }
    self.width = width;
    self.height = height;
    self.ideal_length = compute_ideal_length(self.labels.len(), width, height);
}
/// Replaces the graph's labels, edges and edge weights while keeping the
/// position of every node whose label already existed.
///
/// Nodes with a new label are placed randomly within 30 units of the canvas
/// center. The ideal edge length is recomputed and the temperature is set to
/// 0.8.
pub fn update_graph(
    &mut self,
    labels: Vec<String>,
    edges: Vec<(usize, usize)>,
    edge_weights: Vec<u32>,
) {
    use std::collections::HashMap;

    // Look up previous positions by label; for duplicate labels the last
    // occurrence wins.
    let mut known: HashMap<&str, &NodePosition> = HashMap::new();
    for (label, pos) in self.labels.iter().zip(self.positions.iter()) {
        known.insert(label.as_str(), pos);
    }

    let mut rng = rand::rng();
    let (center_x, center_y) = (self.width / 2.0, self.height / 2.0);

    let mut placed: Vec<NodePosition> = Vec::with_capacity(labels.len());
    for label in &labels {
        let position = match known.get(label.as_str()) {
            // NOTE(review): a surviving node keeps its coordinates but is
            // always unpinned - confirm that dropping `pinned` is intended.
            Some(prior) => NodePosition {
                x: prior.x,
                y: prior.y,
                prev_x: prior.prev_x,
                prev_y: prior.prev_y,
                pinned: false,
            },
            None => {
                let x = center_x + rng.random_range(-30.0..30.0);
                let y = center_y + rng.random_range(-30.0..30.0);
                let prev_x = x + rng.random_range(-1.0..1.0);
                let prev_y = y + rng.random_range(-1.0..1.0);
                NodePosition {
                    x,
                    y,
                    prev_x,
                    prev_y,
                    pinned: false,
                }
            }
        };
        placed.push(position);
    }

    self.positions = placed;
    self.labels = labels;
    self.edges = edges;
    self.edge_weights = edge_weights;
    self.ideal_length = compute_ideal_length(self.labels.len(), self.width, self.height);
    self.temperature = 0.8;
}
}
/// Computes the rest length for edges from the node count and canvas size.
///
/// With at most one node the result is 30% of the shorter canvas side.
/// Otherwise it is `0.7 * sqrt(area / node_count)`, raised to at least 25 and
/// then capped at 45% of the shorter canvas side. On canvases whose shorter
/// side is below roughly 55.6 units the cap is smaller than 25; the cap wins
/// there, so the result always fits the canvas.
fn compute_ideal_length(node_count: usize, width: f64, height: f64) -> f64 {
    const MIN_IDEAL_LENGTH: f64 = 25.0;

    let shorter_side = width.min(height);
    if node_count <= 1 {
        return shorter_side * 0.3;
    }
    let area = width * height;
    let k = 0.7 * (area / node_count as f64).sqrt();
    let upper = shorter_side * 0.45;
    // `f64::clamp(MIN_IDEAL_LENGTH, upper)` panics when `upper` is below the
    // minimum, i.e. on small canvases. `max` followed by `min` produces the
    // same value whenever the range is valid and cannot panic otherwise.
    k.max(MIN_IDEAL_LENGTH).min(upper)
}
/// Maps a drift score to one of four colors, with band boundaries at 30, 60
/// and 80 (each inclusive in the lower band).
pub fn drift_color(drift_score: u8) -> Color {
    let (r, g, b) = match drift_score {
        0..=30 => (166, 227, 161),
        31..=60 => (249, 226, 175),
        61..=80 => (250, 179, 135),
        _ => (243, 139, 168),
    };
    Color::Rgb(r, g, b)
}
/// Maps an edge weight to one of five colors, in ascending weight bands
/// (up to 1, 2-3, 4-7, 8-15, above 15).
///
/// A weight of 0 is grouped with 1 in the lowest band. It previously fell
/// through to the catch-all arm and was drawn like the heaviest edges.
pub fn weighted_edge_color(weight: u32) -> Color {
    match weight {
        0..=1 => Color::Rgb(69, 71, 90),
        2..=3 => Color::Rgb(88, 91, 112),
        4..=7 => Color::Rgb(116, 199, 236),
        8..=15 => Color::Rgb(250, 179, 135),
        _ => Color::Rgb(243, 139, 168),
    }
}
/// Returns the display color for a graph relation role.
///
/// `Hover` is plain white; every other role is an RGB triple. `Focus` and
/// `CascadeSource` share one color.
pub fn graph_relation_color(role: GraphRelationSemantic) -> Color {
    let (r, g, b) = match role {
        GraphRelationSemantic::Hover => return Color::White,
        GraphRelationSemantic::Neutral => (166, 173, 200),
        GraphRelationSemantic::Focus | GraphRelationSemantic::CascadeSource => (255, 232, 115),
        GraphRelationSemantic::Inbound => (148, 226, 213),
        GraphRelationSemantic::Outbound => (250, 179, 135),
        GraphRelationSemantic::Bidirectional => (203, 166, 247),
        GraphRelationSemantic::Related => (116, 150, 200),
        GraphRelationSemantic::SearchMatch => (137, 220, 255),
        GraphRelationSemantic::CascadeDimmed => (49, 50, 68),
    };
    Color::Rgb(r, g, b)
}
/// Returns the display color for a cluster-map role.
///
/// `Hovered` reuses the graph focus color, while `Support` and `Neutral` use
/// the shared `ACCENT_LAVENDER` and `FG_TEXT` constants.
pub fn cluster_map_color(role: ClusterMapSemantic) -> Color {
    match role {
        // Roles that borrow a color defined elsewhere in this file.
        ClusterMapSemantic::Neutral => FG_TEXT,
        ClusterMapSemantic::Support => ACCENT_LAVENDER,
        ClusterMapSemantic::Hovered => graph_relation_color(GraphRelationSemantic::Focus),
        // Roles with their own color.
        ClusterMapSemantic::Central => Color::Rgb(166, 227, 161),
        ClusterMapSemantic::ThirdParty => Color::Rgb(125, 173, 189),
        ClusterMapSemantic::Entrypoint => Color::Rgb(250, 179, 135),
    }
}
/// Returns the display color for an overview edge role.
pub fn overview_edge_color(role: OverviewEdgeSemantic) -> Color {
    let (r, g, b) = match role {
        OverviewEdgeSemantic::PrimaryBridge => (166, 227, 161),
        OverviewEdgeSemantic::ExternalBridge => (245, 169, 127),
        OverviewEdgeSemantic::ExternalSink => (108, 112, 136),
    };
    Color::Rgb(r, g, b)
}
/// Picks a color from `NODE_PALETTE`, wrapping around when `index` exceeds
/// the palette length.
pub fn palette_node_color(index: usize) -> Color {
    let slot = index % NODE_PALETTE.len();
    NODE_PALETTE[slot]
}
/// Maps a blast score to one of six colors based on `score * 100`, truncated
/// to an integer.
///
/// The float-to-integer `as` cast saturates, so negative and NaN scores land
/// in the lowest band and very large scores in the highest.
pub fn blast_color(score: f64) -> Color {
    let percent = (score * 100.0) as u32;
    let (r, g, b) = match percent {
        0..=10 => (137, 180, 250),
        11..=25 => (148, 226, 213),
        26..=45 => (166, 227, 161),
        46..=65 => (249, 226, 175),
        66..=80 => (250, 179, 135),
        _ => (243, 139, 168),
    };
    Color::Rgb(r, g, b)
}
/// Maps a cascade distance to one of five colors: 1, 2 and 3 each get their
/// own color, 4-5 share one, and everything else uses the catch-all color.
///
/// NOTE(review): a distance of 0 also falls into the catch-all arm and is
/// therefore colored like distances of 6 and above - confirm this is intended.
pub fn cascade_distance_color(distance: u32) -> Color {
    let (r, g, b) = match distance {
        1 => (243, 139, 168),
        2 => (250, 179, 135),
        3 => (249, 226, 175),
        4..=5 => (166, 227, 161),
        _ => (137, 180, 250),
    };
    Color::Rgb(r, g, b)
}
// Shared UI color constants.
pub const BG_BASE: Color = Color::Rgb(30, 30, 46);
// Also returned by `cluster_map_color` for `ClusterMapSemantic::Neutral`.
pub const FG_TEXT: Color = Color::Rgb(205, 214, 244);
pub const BG_SURFACE: Color = Color::Rgb(36, 39, 58);
pub const ACCENT_BLUE: Color = Color::Rgb(137, 180, 250);
// Also returned by `cluster_map_color` for `ClusterMapSemantic::Support`.
pub const ACCENT_LAVENDER: Color = Color::Rgb(180, 190, 254);
pub const ACCENT_MAUVE: Color = Color::Rgb(203, 166, 247);
pub const FG_OVERLAY: Color = Color::Rgb(108, 112, 134);
// Eight node colors; `palette_node_color` indexes this array modulo its length.
pub const NODE_PALETTE: [Color; 8] = [
Color::Rgb(148, 226, 213), Color::Rgb(137, 180, 250), Color::Rgb(203, 166, 247), Color::Rgb(166, 227, 161), Color::Rgb(249, 226, 175), Color::Rgb(250, 179, 135), Color::Rgb(243, 139, 168), Color::Rgb(180, 190, 254), ];
#[cfg(test)]
mod tests {
    //! Unit tests for the graph layout simulation and the color helpers.

    use super::*;

    /// A freshly built layout mirrors its inputs and starts at temperature 1.0.
    #[test]
    fn test_graph_layout_creation() {
        let labels = vec!["A".to_string(), "B".to_string(), "C".to_string()];
        let edges = vec![(0, 1), (1, 2)];
        let weights = vec![1, 2];
        let layout = GraphLayout::new(labels.clone(), edges, weights, 200.0, 100.0);
        assert_eq!(layout.positions.len(), 3, "Should have 3 node positions");
        assert_eq!(layout.labels.len(), 3);
        assert_eq!(layout.edges.len(), 2);
        assert_eq!(layout.edge_weights.len(), 2);
        assert!(layout.ideal_length > 0.0, "Ideal length should be positive");
        assert!(
            (layout.temperature - 1.0).abs() < 0.001,
            "Initial temperature should be 1.0"
        );
    }

    /// Stepping moves nodes but keeps them inside the canvas.
    #[test]
    fn test_verlet_step_convergence() {
        let labels = vec!["A".to_string(), "B".to_string()];
        let edges = vec![(0, 1)];
        let weights = vec![1];
        let mut layout = GraphLayout::new(labels, edges, weights, 200.0, 100.0);
        let initial_ax = layout.positions[0].x;
        for _ in 0..100 {
            layout.step();
        }
        let final_ax = layout.positions[0].x;
        assert!(
            layout.positions[0].x >= 0.0 && layout.positions[0].x <= 200.0,
            "X should be within bounds"
        );
        assert!(
            layout.positions[0].y >= 0.0 && layout.positions[0].y <= 100.0,
            "Y should be within bounds"
        );
        let moved = (final_ax - initial_ax).abs() > 0.001;
        assert!(moved, "Nodes should move");
    }

    #[test]
    fn test_multi_step() {
        let labels = vec!["A".to_string(), "B".to_string()];
        let edges = vec![(0, 1)];
        let weights = vec![1];
        let mut layout = GraphLayout::new(labels, edges, weights, 200.0, 100.0);
        let initial_ax = layout.positions[0].x;
        layout.multi_step(50);
        let moved = (layout.positions[0].x - initial_ax).abs() > 0.001;
        assert!(moved, "Nodes should move after multi_step");
    }

    /// The initial placement of a single node lies inside the canvas.
    #[test]
    fn test_bounds_enforcement() {
        let labels = vec!["A".to_string()];
        let layout = GraphLayout::new(labels, vec![], vec![], 100.0, 50.0);
        let pos = &layout.positions[0];
        assert!(pos.x >= 0.0 && pos.x <= 100.0);
        assert!(pos.y >= 0.0 && pos.y <= 50.0);
    }

    #[test]
    fn test_drift_color_ranges() {
        assert_eq!(drift_color(0), Color::Rgb(166, 227, 161));
        assert_eq!(drift_color(30), Color::Rgb(166, 227, 161));
        assert_eq!(drift_color(31), Color::Rgb(249, 226, 175));
        assert_eq!(drift_color(60), Color::Rgb(249, 226, 175));
        assert_eq!(drift_color(61), Color::Rgb(250, 179, 135));
        assert_eq!(drift_color(81), Color::Rgb(243, 139, 168));
        assert_eq!(drift_color(100), Color::Rgb(243, 139, 168));
    }

    #[test]
    fn test_graph_relation_semantics_are_distinct() {
        assert_eq!(
            graph_relation_color(GraphRelationSemantic::Hover),
            Color::White
        );
        assert_ne!(
            graph_relation_color(GraphRelationSemantic::Neutral),
            graph_relation_color(GraphRelationSemantic::Focus)
        );
        assert_ne!(
            graph_relation_color(GraphRelationSemantic::SearchMatch),
            graph_relation_color(GraphRelationSemantic::Related)
        );
    }

    #[test]
    fn test_cluster_and_edge_semantics_are_stable() {
        assert_ne!(
            cluster_map_color(ClusterMapSemantic::ThirdParty),
            cluster_map_color(ClusterMapSemantic::Central)
        );
        assert_ne!(
            overview_edge_color(OverviewEdgeSemantic::PrimaryBridge),
            overview_edge_color(OverviewEdgeSemantic::ExternalSink)
        );
    }

    /// Stepping an empty layout is a no-op and must not panic.
    #[test]
    fn test_empty_graph_step() {
        let mut layout = GraphLayout::new(vec![], vec![], vec![], 100.0, 100.0);
        layout.step();
        assert_eq!(layout.positions.len(), 0);
    }

    #[test]
    fn test_update_graph_preserves_positions() {
        let labels = vec!["A".to_string(), "B".to_string()];
        let edges = vec![(0, 1)];
        let weights = vec![1];
        let mut layout = GraphLayout::new(labels, edges, weights, 200.0, 100.0);
        for _ in 0..10 {
            layout.step();
        }
        let a_pos = layout.positions[0].x;
        let new_labels = vec!["A".to_string(), "C".to_string()];
        let new_edges = vec![(0, 1)];
        let new_weights = vec![1];
        layout.update_graph(new_labels, new_edges, new_weights);
        assert_eq!(layout.positions.len(), 2);
        assert!(
            (layout.positions[0].x - a_pos).abs() < 0.01,
            "A position should be preserved"
        );
    }

    #[test]
    fn test_compute_ideal_length() {
        let k = compute_ideal_length(3, 100.0, 100.0);
        assert!(k > 15.0, "Ideal length should be reasonable: {k}");
        assert!(k < 50.0, "Ideal length should not be too large: {k}");
        let k1 = compute_ideal_length(1, 100.0, 100.0);
        assert!(k1 > 0.0, "Single node should have positive ideal length");
    }

    #[test]
    fn test_resize_rescales() {
        let labels = vec!["A".to_string(), "B".to_string()];
        let mut layout = GraphLayout::new(labels, vec![], vec![], 100.0, 100.0);
        layout.positions[0].x = 50.0;
        layout.positions[0].y = 50.0;
        layout.positions[0].prev_x = 50.0;
        layout.positions[0].prev_y = 50.0;
        layout.resize(200.0, 100.0);
        assert!(
            (layout.positions[0].x - 100.0).abs() < 0.01,
            "X should double when width doubles"
        );
        assert!(
            (layout.positions[0].y - 50.0).abs() < 0.01,
            "Y should stay same when height unchanged"
        );
    }

    #[test]
    fn test_temperature_decay() {
        let labels = vec!["A".to_string(), "B".to_string()];
        let mut layout = GraphLayout::new(labels, vec![(0, 1)], vec![1], 200.0, 100.0);
        let initial_temp = layout.temperature;
        layout.multi_step(100);
        assert!(
            layout.temperature < initial_temp,
            "Temperature should decay over time"
        );
        // `step` floors the temperature at 0.001; the old bound of 0.01 only
        // held because 100 steps do not cool the layout that far.
        assert!(
            layout.temperature >= 0.001,
            "Temperature should not drop below minimum"
        );
    }

    #[test]
    fn test_reheat() {
        let labels = vec!["A".to_string(), "B".to_string()];
        let mut layout = GraphLayout::new(labels, vec![(0, 1)], vec![1], 200.0, 100.0);
        layout.multi_step(200);
        let cold_temp = layout.temperature;
        layout.reheat();
        assert!(
            layout.temperature > cold_temp,
            "Temperature should increase after reheat"
        );
        assert!(
            (layout.temperature - 1.5).abs() < 0.001,
            "Reheat should set temperature to 1.5"
        );
    }
}