use std::collections::{BTreeSet, HashMap, VecDeque};
use eframe::egui::{self, Align2, Color32, FontId, Pos2, Sense, Stroke, Vec2};
use eframe::egui::style::ScrollAnimation;
/// Width of a node chip in graph units (screen pixels at zoom 1.0).
pub const BOX_W: f32 = 184.0_f32;
/// Height of a node chip in graph units (screen pixels at zoom 1.0).
pub const BOX_H: f32 = 34.0_f32;
/// A node of the graph, drawn as a [`BOX_W`] × [`BOX_H`] chip centred on [`GraphNode::pos`].
#[derive(Clone, Debug)]
pub struct GraphNode {
    /// Identifier used by edges and decorations to refer to this node; assumed unique.
    pub id: String,
    /// Text drawn in the middle of the chip.
    pub label: String,
    /// Chip background colour.
    pub fill: Color32,
    /// Chip outline colour, used when no selection or decoration ring overrides it.
    pub stroke: Color32,
    /// Chip centre in graph coordinates, before pan/zoom projection.
    pub pos: Pos2,
}

/// A directed edge between two nodes, referenced by id.
#[derive(Clone, Debug)]
pub struct GraphEdge {
    /// Id of the source node; the edge leaves the right side of its chip.
    pub from: String,
    /// Id of the target node; the edge enters the left side of its chip.
    pub to: String,
    /// Stroke colour of the curve and arrowhead.
    pub color: Color32,
    /// Draw as a dashed rather than a solid curve.
    pub dashed: bool,
    /// Optional caption drawn above the edge's midpoint; empty labels are not drawn.
    pub label: Option<String>,
}

/// The nodes and edges to draw.
///
/// Edges whose endpoints are not present in `nodes` are skipped when drawn.
#[derive(Clone, Debug, Default)]
pub struct GraphModel {
    pub nodes: Vec<GraphNode>,
    pub edges: Vec<GraphEdge>,
}

/// Per-node overlay drawn on top of the model's own styling.
#[derive(Clone, Debug, Default)]
pub struct NodeDecoration {
    /// Outline colour replacing the node's own stroke (a selection ring still takes priority).
    pub ring: Option<Color32>,
    /// Short text drawn in the chip's top-right corner.
    pub badge: Option<String>,
    /// Badge colour; falls back to the regular text colour when `None`.
    pub badge_color: Option<Color32>,
}

/// Extra annotations layered over a [`GraphModel`].
#[derive(Clone, Debug, Default)]
pub struct Decorations {
    /// Node overlays keyed by [`GraphNode::id`].
    pub nodes: HashMap<String, NodeDecoration>,
    /// Additional edges drawn emphasised on top of the model's edges and never dimmed.
    pub edges: Vec<GraphEdge>,
}
/// Per-widget view state: pan, zoom and the selected node.
#[derive(Clone, Debug)]
pub struct GraphView {
    /// Screen-space offset of the graph origin from the canvas centre (used by `draw_graph`).
    pub pan: Vec2,
    /// Graph-to-screen scale factor; non-positive values are reset to 1.0 when drawn.
    pub zoom: f32,
    /// Id of the selected node, if any; its downstream closure is highlighted.
    pub selected: Option<String>,
}
impl Default for GraphView {
    /// An unpanned view at zoom 1.0 with nothing selected.
    fn default() -> Self {
        let (pan, zoom, selected) = (Vec2::ZERO, 1.0, None);
        Self { pan, zoom, selected }
    }
}
impl GraphView {
    /// Resets pan to zero and zoom to 1.0, keeping the current selection.
    pub fn fit(&mut self) {
        let Self { pan, zoom, .. } = self;
        *pan = Vec2::ZERO;
        *zoom = 1.0;
    }

    /// Drops the current selection, if any.
    pub fn clear_selection(&mut self) {
        self.selected.take();
    }
}
/// Returns `seed` plus every node id reachable from it by following edges forward.
///
/// This is a breadth-first search over `model.edges`. `seed` is always part of the result,
/// even when no node carries that id, and each id is enqueued at most once, so cycles terminate.
pub fn downstream_of(model: &GraphModel, seed: &str) -> BTreeSet<String> {
    let mut successors: HashMap<&str, Vec<&str>> = HashMap::new();
    for edge in &model.edges {
        successors
            .entry(edge.from.as_str())
            .or_default()
            .push(edge.to.as_str());
    }

    let mut reached = BTreeSet::from([seed.to_owned()]);
    let mut frontier = VecDeque::from([seed]);
    while let Some(node) = frontier.pop_front() {
        let next = successors.get(node).into_iter().flatten().copied();
        for target in next {
            // `insert` returns false for ids already reached, so they are not re-queued.
            if reached.insert(target.to_owned()) {
                frontier.push_back(target);
            }
        }
    }
    reached
}
/// Fades a colour for elements outside the highlighted trace.
///
/// Delegates to [`Color32::linear_multiply`] with a factor of 0.22.
fn dim(c: Color32) -> Color32 {
    const FADE: f32 = 0.22;
    c.linear_multiply(FADE)
}
/// What the user did with the graph during this frame.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct GraphResponse {
    /// Id of the node that was clicked, if a click landed on a chip.
    pub clicked_node: Option<String>,
    /// `true` when a click landed on the canvas but missed every chip.
    pub clicked_empty: bool,
}
/// Draws `model` as a pannable, zoomable canvas filling the space left in `ui`.
///
/// * Dragging the canvas adds the drag delta to `view.pan`.
/// * Scrolling while hovering multiplies `view.zoom` by `1 + 0.001 * scroll_y`, clamped to
///   `[0.25, 4.0]`.
/// * A click selects the node under the pointer, or clears the selection on empty canvas;
///   the outcome is reported in the returned [`GraphResponse`].
///
/// Graph coordinates are projected relative to the canvas centre shifted by `view.pan`.
// Theme colours are separate arguments; bundling them would change the public signature.
#[allow(clippy::too_many_arguments)]
pub fn draw_graph(
    ui: &mut egui::Ui,
    model: &GraphModel,
    decorations: &Decorations,
    view: &mut GraphView,
    bg: Color32,
    text: Color32,
    selection_ring: Color32,
    text_dim: Color32,
) -> GraphResponse {
    // `clamp` propagates NaN, so a NaN/infinite zoom would otherwise never recover.
    if !view.zoom.is_finite() || view.zoom <= 0.0 {
        view.zoom = 1.0;
    }
    let (resp, painter) = ui.allocate_painter(ui.available_size(), Sense::click_and_drag());
    painter.rect_filled(resp.rect, 4.0, bg);
    if resp.dragged() {
        view.pan += resp.drag_delta();
    }
    if resp.hovered() {
        let scroll = ui.input(|i| i.smooth_scroll_delta.y);
        if scroll != 0.0 {
            view.zoom = (view.zoom * (1.0 + scroll * 0.001)).clamp(0.25, 4.0);
        }
    }
    let origin = resp.rect.center() + view.pan;
    let click = resp.clicked().then(|| resp.interact_pointer_pos()).flatten();
    paint_graph(
        &painter, model, decorations, view, origin, click, text, selection_ring, text_dim,
    )
}
/// Size of the nodes' bounding box plus a 48-unit margin on every side, in graph units.
///
/// The box covers each chip's full [`BOX_W`] × [`BOX_H`] footprint around its centre.
/// An empty model reports the size of a single chip plus margins.
pub fn content_extent(model: &GraphModel) -> Vec2 {
    const MARGIN: f32 = 48.0;
    let chip = Vec2::new(BOX_W, BOX_H);
    let padding = Vec2::splat(2.0 * MARGIN);
    if model.nodes.is_empty() {
        return chip + padding;
    }
    let half = chip * 0.5;
    let seed = (Pos2::new(f32::MAX, f32::MAX), Pos2::new(f32::MIN, f32::MIN));
    let (lo, hi) = model.nodes.iter().fold(seed, |(lo, hi), node| {
        (lo.min(node.pos - half), hi.max(node.pos + half))
    });
    (hi - lo) + padding
}
#[allow(clippy::too_many_arguments)]
pub fn draw_graph_scrolled(
ui: &mut egui::Ui,
model: &GraphModel,
decorations: &Decorations,
view: &mut GraphView,
bg: Color32,
text: Color32,
selection_ring: Color32,
text_dim: Color32,
) -> GraphResponse {
if view.zoom <= 0.0 {
view.zoom = 1.0;
}
ui.style_mut().scroll_animation = ScrollAnimation::new(1200.0, egui::Rangef::new(0.08, 0.30));
{
let s = &mut ui.style_mut().spacing.scroll;
s.bar_width = 12.0;
s.handle_min_length = 24.0;
s.floating = false;
s.dormant_background_opacity = 1.0;
s.active_background_opacity = 1.0;
s.interact_background_opacity = 1.0;
}
let extent = content_extent(model) * view.zoom;
let mut out = GraphResponse::default();
egui::ScrollArea::both()
.id_salt("graph_scroll")
.auto_shrink([false, false])
.scroll_bar_visibility(egui::scroll_area::ScrollBarVisibility::AlwaysVisible)
.show(ui, |ui| {
let (resp, painter) =
ui.allocate_painter(extent.max(ui.available_size()), Sense::click_and_drag());
painter.rect_filled(resp.rect, 4.0, bg);
if resp.hovered() {
let (scroll, zoom_mod) =
ui.input(|i| (i.smooth_scroll_delta.y, i.modifiers.command || i.modifiers.ctrl));
if zoom_mod && scroll != 0.0 {
view.zoom = (view.zoom * (1.0 + scroll * 0.001)).clamp(0.25, 4.0);
}
}
let origin = resp.rect.center();
let click = resp.clicked().then(|| resp.interact_pointer_pos()).flatten();
out = paint_graph(
&painter, model, decorations, view, origin, click, text, selection_ring, text_dim,
);
});
out
}
/// Paints model edges, decoration edges and node chips, and resolves a pending click.
///
/// Graph positions are projected to screen space as `origin + pos * view.zoom`.
///
/// When `click` is `Some`, the top-most node whose chip contains it becomes `view.selected`
/// and is reported in `clicked_node`. A click on empty canvas clears the selection and sets
/// `clicked_empty`. While a node is selected, nodes and model edges outside its
/// [`downstream_of`] closure are dimmed; decoration edges never are.
// Theme colours are separate arguments to mirror the public draw functions.
#[allow(clippy::too_many_arguments)]
fn paint_graph(
    painter: &egui::Painter,
    model: &GraphModel,
    decorations: &Decorations,
    view: &mut GraphView,
    origin: Pos2,
    click: Option<Pos2>,
    text: Color32,
    selection_ring: Color32,
    text_dim: Color32,
) -> GraphResponse {
    let mut out = GraphResponse::default();
    let zoom = view.zoom;
    let project = |p: Pos2| origin + p.to_vec2() * zoom;
    let idx: HashMap<&str, usize> =
        model.nodes.iter().enumerate().map(|(i, n)| (n.id.as_str(), i)).collect();

    if let Some(click) = click {
        match pick_node(model, &project, zoom, click) {
            Some(id) => {
                view.selected = Some(id.clone());
                out.clicked_node = Some(id);
            }
            None => {
                view.selected = None;
                out.clicked_empty = true;
            }
        }
    }

    // Read the selection after applying the click so highlighting reacts in the same frame.
    let lit: BTreeSet<String> = view
        .selected
        .as_deref()
        .map(|s| downstream_of(model, s))
        .unwrap_or_default();
    let highlighting = view.selected.is_some();

    for e in &model.edges {
        draw_edge(painter, model, &idx, &project, e, zoom, highlighting, &lit, text_dim, false);
    }
    // Decoration edges go on top, emphasised and never dimmed.
    for e in &decorations.edges {
        draw_edge(painter, model, &idx, &project, e, zoom, false, &lit, text_dim, true);
    }

    let size = Vec2::new(BOX_W, BOX_H) * zoom;
    let font_scale = zoom.clamp(0.7, 1.4);
    for n in &model.nodes {
        let c = project(n.pos);
        // Nodes outside the selection's downstream trace are faded.
        let faded = highlighting && !lit.contains(&n.id);
        let fade = |col: Color32| if faded { dim(col) } else { col };
        let rect = egui::Rect::from_center_size(c, size);
        painter.rect_filled(rect, 5.0 * zoom, fade(n.fill));

        let deco = decorations.nodes.get(&n.id);
        // Outline priority: selection ring, then decoration ring, then the node's own stroke.
        let ring = if view.selected.as_deref() == Some(n.id.as_str()) {
            Stroke::new(3.0, selection_ring)
        } else if let Some(rc) = deco.and_then(|d| d.ring) {
            Stroke::new(2.4, fade(rc))
        } else {
            Stroke::new(1.4, fade(n.stroke))
        };
        painter.rect_stroke(rect, 5.0 * zoom, ring, egui::epaint::StrokeKind::Outside);

        // Text is skipped when zoomed far out.
        if zoom > 0.45 {
            painter.text(
                c,
                Align2::CENTER_CENTER,
                &n.label,
                FontId::proportional(11.0 * font_scale),
                text,
            );
            if let Some(badge) = deco.and_then(|d| d.badge.as_deref()) {
                let col = fade(deco.and_then(|d| d.badge_color).unwrap_or(text));
                painter.text(
                    rect.right_top() + Vec2::new(-2.0, 1.0),
                    Align2::RIGHT_TOP,
                    badge,
                    FontId::proportional(12.0 * font_scale),
                    col,
                );
            }
        }
    }
    out
}

/// Returns the id of the top-most node whose chip contains `click`.
///
/// Nodes are painted in model order, so later nodes cover earlier ones; searching in reverse
/// makes the hit test agree with what is visible where chips overlap.
fn pick_node(
    model: &GraphModel,
    project: &impl Fn(Pos2) -> Pos2,
    zoom: f32,
    click: Pos2,
) -> Option<String> {
    let size = Vec2::new(BOX_W, BOX_H) * zoom;
    model.nodes.iter().rev().find_map(|n| {
        egui::Rect::from_center_size(project(n.pos), size)
            .contains(click)
            .then(|| n.id.clone())
    })
}
/// Draws one edge with an arrowhead and optional label.
///
/// The edge is a cubic Bézier from the right side of the source chip to the left side of the
/// target chip, with control points that make it leave and enter horizontally.
///
/// Edges referencing an id missing from `idx` are skipped. When `highlighting` is set, edges
/// not fully inside `lit` are dimmed unless `emphasise` is set. Emphasised edges are drawn
/// thicker and label in their own colour instead of `text_dim`.
// Many independent styling inputs; grouping them would only move the arguments elsewhere.
#[allow(clippy::too_many_arguments)]
fn draw_edge(
    painter: &egui::Painter,
    model: &GraphModel,
    idx: &HashMap<&str, usize>,
    project: &impl Fn(Pos2) -> Pos2,
    e: &GraphEdge,
    zoom: f32,
    highlighting: bool,
    lit: &BTreeSet<String>,
    text_dim: Color32,
    emphasise: bool,
) {
    let (Some(&fi), Some(&ti)) = (idx.get(e.from.as_str()), idx.get(e.to.as_str())) else {
        return;
    };
    let on_trace = highlighting && lit.contains(&e.from) && lit.contains(&e.to);
    let half_w = Vec2::new(BOX_W * 0.5 * zoom, 0.0);
    let a = project(model.nodes[fi].pos) + half_w;
    let b = project(model.nodes[ti].pos) - half_w;
    let color = if !emphasise && highlighting && !on_trace { dim(e.color) } else { e.color };
    let w = if emphasise { 2.6 } else if on_trace { 2.4 } else { 1.4 };
    let stroke = Stroke::new(w, color);

    // Control points share the endpoints' y so the curve leaves and enters horizontally.
    let mid_x = (a.x + b.x) * 0.5;
    let mid = Pos2::new(mid_x, a.y);
    let mid2 = Pos2::new(mid_x, b.y);
    let curve = egui::epaint::CubicBezierShape::from_points_stroke(
        [a, mid, mid2, b],
        false,
        Color32::TRANSPARENT,
        stroke,
    );
    if e.dashed {
        // Approximate a dash pattern: sample 21 points and draw every other segment.
        let pts: Vec<Pos2> = (0..=20).map(|i| curve.sample(i as f32 / 20.0)).collect();
        for seg in pts.windows(2).step_by(2) {
            painter.line_segment([seg[0], seg[1]], stroke);
        }
    } else {
        painter.add(egui::Shape::CubicBezier(curve));
    }

    // The arrowhead follows the curve's end tangent (b - mid2). That tangent vanishes when
    // both endpoints share an x coordinate, collapsing the arrowhead to a point, so fall
    // back to pointing right, into the target chip's left side.
    let dir = Some((b - mid2).normalized())
        .filter(|d| d.is_finite() && *d != Vec2::ZERO)
        .unwrap_or(Vec2::new(1.0, 0.0));
    let perp = Vec2::new(-dir.y, dir.x);
    let head = 6.0 * zoom.clamp(0.6, 1.6);
    painter.line_segment([b, b - dir * head + perp * head * 0.5], stroke);
    painter.line_segment([b, b - dir * head - perp * head * 0.5], stroke);

    // Labels are skipped when zoomed far out or empty.
    if let Some(lbl) = e.label.as_deref().filter(|l| zoom > 0.45 && !l.is_empty()) {
        let mp = Pos2::new(mid_x, (a.y + b.y) * 0.5 - 6.0 * zoom);
        painter.text(
            mp,
            Align2::CENTER_BOTTOM,
            lbl,
            FontId::proportional(10.0 * zoom.clamp(0.7, 1.3)),
            if emphasise { color } else { text_dim },
        );
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A grey chip with the given id (also used as label) at `(x, 0)`.
    fn node(id: &str, x: f32) -> GraphNode {
        GraphNode {
            id: id.into(),
            label: id.into(),
            fill: Color32::GRAY,
            stroke: Color32::WHITE,
            pos: Pos2::new(x, 0.0),
        }
    }

    /// A plain solid, unlabeled edge.
    fn edge(from: &str, to: &str) -> GraphEdge {
        GraphEdge { from: from.into(), to: to.into(), color: Color32::WHITE, dashed: false, label: None }
    }

    /// Builds an owned id set for exact comparisons.
    fn ids(items: &[&str]) -> BTreeSet<String> {
        items.iter().map(|s| s.to_string()).collect()
    }

    #[test]
    fn downstream_is_bfs_closure_from_seed() {
        let model = GraphModel {
            nodes: vec![node("A", 0.0), node("B", 100.0), node("C", 200.0)],
            edges: vec![edge("A", "B"), edge("B", "C"), edge("A", "C")],
        };
        assert_eq!(downstream_of(&model, "A"), ids(&["A", "B", "C"]));
        let from_b = downstream_of(&model, "B");
        assert_eq!(from_b, ids(&["B", "C"]));
        assert!(!from_b.contains("A"), "BFS is forward-only — A is not downstream of B");
    }

    #[test]
    fn downstream_terminates_on_cycles_and_keeps_unknown_seed() {
        let model = GraphModel {
            nodes: vec![node("A", 0.0), node("B", 100.0)],
            edges: vec![edge("A", "B"), edge("B", "A")],
        };
        assert_eq!(downstream_of(&model, "A"), ids(&["A", "B"]));
        assert_eq!(downstream_of(&model, "Z"), ids(&["Z"]), "the seed is always included");
    }

    #[test]
    fn content_extent_is_the_true_bounding_box_not_the_viewport() {
        const MARGIN: f32 = 48.0;
        let empty = GraphModel::default();
        let e0 = content_extent(&empty);
        assert_eq!(e0, Vec2::new(BOX_W, BOX_H) + Vec2::splat(2.0 * MARGIN));
        let two = GraphModel {
            nodes: vec![node("A", 0.0), node("B", 300.0)],
            edges: vec![],
        };
        let e = content_extent(&two);
        assert_eq!(e.x, 300.0 + BOX_W + 2.0 * MARGIN, "width spans both chips + margins");
        assert_eq!(e.y, BOX_H + 2.0 * MARGIN, "height is one chip row + margins");
        let wider = GraphModel {
            nodes: vec![node("A", 0.0), node("B", 900.0)],
            edges: vec![],
        };
        assert!(content_extent(&wider).x > e.x, "a wider graph reports a wider extent");
    }

    #[test]
    fn view_fit_and_clear_reset_state() {
        let mut v = GraphView { pan: Vec2::new(5.0, 5.0), zoom: 2.0, selected: Some("A".into()) };
        v.fit();
        assert_eq!(v.pan, Vec2::ZERO);
        assert_eq!(v.zoom, 1.0);
        assert_eq!(v.selected.as_deref(), Some("A"), "fit keeps the selection");
        v.clear_selection();
        assert!(v.selected.is_none());
    }
}