egui_graphs/
graph.rs

1use std::collections::HashSet;
2
3use egui::{Pos2, Rect};
4use petgraph::stable_graph::DefaultIx;
5use petgraph::Directed;
6
7use petgraph::graph::IndexType;
8use petgraph::{
9    stable_graph::{EdgeIndex, EdgeReference, NodeIndex, StableGraph},
10    visit::{EdgeRef, IntoEdgeReferences, IntoNodeReferences},
11    Direction, EdgeType,
12};
13use serde::{Deserialize, Serialize};
14
15use crate::draw::{DisplayEdge, DisplayNode};
16use crate::{
17    default_edge_transform, default_node_transform, to_graph, DefaultEdgeShape, DefaultNodeShape,
18};
19use crate::{metadata::Metadata, Edge, Node};
20
/// Concrete petgraph storage behind [`Graph`]: node weights are [`Node`]s and edge weights are
/// [`Edge`]s, sharing the payload (`N`, `E`), edge-type (`Ty`), index (`Ix`) and display
/// (`Dn`, `De`) type parameters.
type StableGraphType<N, E, Ty, Ix, Dn, De> =
    StableGraph<Node<N, E, Ty, Ix, Dn>, Edge<N, E, Ty, Ix, Dn, De>, Ty, Ix>;
23
/// Wrapper around [`petgraph::stable_graph::StableGraph`] compatible with [`super::GraphView`].
/// It is used to store graph data and provide access to it.
///
/// Alongside the graph it stores interaction state (selected nodes and edges, the dragged
/// and hovered node) and a bounds rectangle, each exposed through getter/setter methods.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Graph<
    N = (),
    E = (),
    Ty = Directed,
    Ix = DefaultIx,
    Dn = DefaultNodeShape,
    De = DefaultEdgeShape,
> where
    N: Clone,
    E: Clone,
    Ty: EdgeType,
    Ix: IndexType,
    Dn: DisplayNode<N, E, Ty, Ix>,
    De: DisplayEdge<N, E, Ty, Ix, Dn>,
{
    /// Underlying petgraph storage; node weights are [`Node`]s and edge weights are [`Edge`]s.
    g: StableGraphType<N, E, Ty, Ix, Dn, De>,

    /// Node indices recorded as selected (replaced wholesale by `set_selected_nodes`).
    selected_nodes: Vec<NodeIndex<Ix>>,
    /// Edge indices recorded as selected (replaced wholesale by `set_selected_edges`).
    selected_edges: Vec<EdgeIndex<Ix>>,
    /// Node recorded as currently dragged, if any.
    dragged_node: Option<NodeIndex<Ix>>,
    /// Node recorded as currently hovered, if any.
    hovered_node: Option<NodeIndex<Ix>>,

    /// Rectangle stored via `set_bounds`; zero-sized at the origin after `new`.
    /// NOTE(review): presumably the canvas extent of the drawn graph — confirm with caller.
    bounds: Rect,
}
51
52impl<N, E, Ty, Ix, Dn, De> From<&StableGraph<N, E, Ty, Ix>> for Graph<N, E, Ty, Ix, Dn, De>
53where
54    N: Clone,
55    E: Clone,
56    Ty: EdgeType,
57    Ix: IndexType,
58    Dn: DisplayNode<N, E, Ty, Ix>,
59    De: DisplayEdge<N, E, Ty, Ix, Dn>,
60{
61    fn from(g: &StableGraph<N, E, Ty, Ix>) -> Self {
62        to_graph(g)
63    }
64}
65
66impl<N, E, Ty, Ix, Dn, De> Graph<N, E, Ty, Ix, Dn, De>
67where
68    N: Clone,
69    E: Clone,
70    Ty: EdgeType,
71    Ix: IndexType,
72    Dn: DisplayNode<N, E, Ty, Ix>,
73    De: DisplayEdge<N, E, Ty, Ix, Dn>,
74{
75    pub fn new(g: StableGraphType<N, E, Ty, Ix, Dn, De>) -> Self {
76        Self {
77            g,
78            selected_nodes: Vec::default(),
79            selected_edges: Vec::default(),
80            dragged_node: Option::default(),
81            hovered_node: Option::default(),
82            bounds: Rect::from_min_max(Pos2::ZERO, Pos2::ZERO),
83        }
84    }
85
86    /// Finds node by position. Can be optimized by using a spatial index like quad-tree if needed.
87    pub fn node_by_screen_pos(&self, meta: &Metadata, screen_pos: Pos2) -> Option<NodeIndex<Ix>> {
88        let pos_in_graph = meta.screen_to_canvas_pos(screen_pos);
89        for (idx, node) in self.nodes_iter() {
90            let display = node.display();
91            if display.is_inside(pos_in_graph) {
92                return Some(idx);
93            }
94        }
95        None
96    }
97
98    /// Finds edge by position.
99    #[allow(clippy::missing_panics_doc)] // TODO: add panics doc
100    pub fn edge_by_screen_pos(&self, meta: &Metadata, screen_pos: Pos2) -> Option<EdgeIndex<Ix>> {
101        let pos_in_graph = meta.screen_to_canvas_pos(screen_pos);
102        for (idx, e) in self.edges_iter() {
103            let Some((idx_start, idx_end)) = self.g.edge_endpoints(e.id()) else {
104                continue;
105            };
106            let start = self.g.node_weight(idx_start).unwrap();
107            let end = self.g.node_weight(idx_end).unwrap();
108            if e.display().is_inside(start, end, pos_in_graph) {
109                return Some(idx);
110            }
111        }
112
113        None
114    }
115
116    pub fn g_mut(&mut self) -> &mut StableGraphType<N, E, Ty, Ix, Dn, De> {
117        &mut self.g
118    }
119
120    pub fn g(&self) -> &StableGraphType<N, E, Ty, Ix, Dn, De> {
121        &self.g
122    }
123
124    /// Adds node to graph setting default location and default label values
125    #[allow(clippy::missing_panics_doc)] // TODO: add panics doc
126    pub fn add_node(&mut self, payload: N) -> NodeIndex<Ix> {
127        self.add_node_custom(payload, default_node_transform)
128    }
129
130    #[allow(clippy::missing_panics_doc)] // TODO: add panics doc
131    pub fn add_node_custom(
132        &mut self,
133        payload: N,
134        node_transform: impl FnOnce(&mut Node<N, E, Ty, Ix, Dn>),
135    ) -> NodeIndex<Ix> {
136        let node = Node::new(payload);
137
138        let idx = self.g.add_node(node);
139        let graph_node = self.g.node_weight_mut(idx).unwrap();
140
141        graph_node.set_id(idx);
142
143        node_transform(graph_node);
144
145        idx
146    }
147
148    /// Adds node to graph setting custom location and default label value
149    #[allow(clippy::missing_panics_doc)] // TODO: add panics doc
150    pub fn add_node_with_location(&mut self, payload: N, location: Pos2) -> NodeIndex<Ix> {
151        self.add_node_custom(payload, |n: &mut Node<N, E, Ty, Ix, Dn>| {
152            n.set_location(location);
153        })
154    }
155
156    /// Adds node to graph setting default location and custom label value
157    pub fn add_node_with_label(&mut self, payload: N, label: String) -> NodeIndex<Ix> {
158        self.add_node_custom(payload, |n: &mut Node<N, E, Ty, Ix, Dn>| {
159            n.set_label(label);
160        })
161    }
162
163    /// Adds node to graph setting custom location and custom label value
164    #[allow(clippy::missing_panics_doc)] // TODO: add panics doc
165    pub fn add_node_with_label_and_location(
166        &mut self,
167        payload: N,
168        label: String,
169        location: Pos2,
170    ) -> NodeIndex<Ix> {
171        self.add_node_custom(payload, |n: &mut Node<N, E, Ty, Ix, Dn>| {
172            n.set_location(location);
173            n.set_label(label);
174        })
175    }
176
177    /// Removes node by index. Returns removed node and None if it does not exist.
178    pub fn remove_node(&mut self, idx: NodeIndex<Ix>) -> Option<Node<N, E, Ty, Ix, Dn>> {
179        // before removing nodes we need to remove all edges connected to it
180        let neighbors = self.g.neighbors_undirected(idx).collect::<Vec<_>>();
181        for n in &neighbors {
182            self.remove_edges_between(idx, *n);
183            self.remove_edges_between(*n, idx);
184        }
185
186        self.g.remove_node(idx)
187    }
188
189    /// Removes all edges between start and end node. Returns removed edges count.
190    #[allow(clippy::missing_panics_doc)] // TODO: add panics doc
191    pub fn remove_edges_between(&mut self, start: NodeIndex<Ix>, end: NodeIndex<Ix>) -> usize {
192        let idxs = self
193            .g
194            .edges_connecting(start, end)
195            .map(|e| e.id())
196            .collect::<Vec<_>>();
197        if idxs.is_empty() {
198            return 0;
199        }
200
201        let mut removed = 0;
202        for e in &idxs {
203            self.g.remove_edge(*e).unwrap();
204            removed += 1;
205        }
206
207        removed
208    }
209
210    /// Adds edge between start and end node with default label.
211    #[allow(clippy::missing_panics_doc)] // TODO: add panics doc
212    pub fn add_edge(
213        &mut self,
214        start: NodeIndex<Ix>,
215        end: NodeIndex<Ix>,
216        payload: E,
217    ) -> EdgeIndex<Ix> {
218        self.add_edge_custom(start, end, payload, default_edge_transform)
219    }
220
221    /// Adds edge between start and end node with custom label setting correct order.
222    #[allow(clippy::missing_panics_doc)] // TODO: add panics doc
223    pub fn add_edge_with_label(
224        &mut self,
225        start: NodeIndex<Ix>,
226        end: NodeIndex<Ix>,
227        payload: E,
228        label: String,
229    ) -> EdgeIndex<Ix> {
230        self.add_edge_custom(start, end, payload, |e: &mut Edge<N, E, Ty, Ix, Dn, De>| {
231            e.set_label(label);
232        })
233    }
234
235    #[allow(clippy::missing_panics_doc)] // TODO: add panics doc
236    pub fn add_edge_custom(
237        &mut self,
238        start: NodeIndex<Ix>,
239        end: NodeIndex<Ix>,
240        payload: E,
241        edge_transform: impl FnOnce(&mut Edge<N, E, Ty, Ix, Dn, De>),
242    ) -> EdgeIndex<Ix> {
243        // Choose the smallest non-negative order not yet used by edges in the SAME direction
244        // to avoid multiple edges sharing the same visual offset (stacking).
245        let used_orders: std::collections::HashSet<usize> = self
246            .g
247            .edges_connecting(start, end)
248            .map(|e| e.weight().order())
249            .collect();
250        let mut order = 0usize;
251        while used_orders.contains(&order) {
252            order += 1;
253        }
254
255        let idx = self.g.add_edge(start, end, Edge::new(payload));
256        let e = self.g.edge_weight_mut(idx).unwrap();
257
258        e.set_id(idx);
259        e.set_order(order);
260
261        edge_transform(e);
262
263        // If we have two opposite-direction edges with order 0 (two straight lines),
264        // bump all siblings' order by 1 to avoid overlapping straight segments.
265
266        let siblings_ids: Vec<_> = {
267            let mut visited = HashSet::new();
268            self.g
269                .edges_connecting(start, end)
270                .chain(self.g.edges_connecting(end, start))
271                .filter(|e| visited.insert(e.id()))
272                .map(|e| e.id())
273                .collect()
274        };
275
276        let mut had_zero = false;
277        let mut increase_order = false;
278        for id in &siblings_ids {
279            if let Some(edge) = self.g.edge_weight_mut(*id) {
280                if edge.order() == 0 {
281                    if had_zero {
282                        increase_order = true;
283                        break;
284                    }
285
286                    had_zero = true;
287                }
288            }
289        }
290
291        if increase_order {
292            for id in siblings_ids {
293                if let Some(edge) = self.g.edge_weight_mut(id) {
294                    edge.set_order(edge.order() + 1);
295                }
296            }
297        }
298
299        idx
300    }
301
302    /// Removes edge by index and updates order of the siblings.
303    /// Returns removed edge and None if it does not exist.
304    pub fn remove_edge(&mut self, idx: EdgeIndex<Ix>) -> Option<Edge<N, E, Ty, Ix, Dn, De>> {
305        let (start, end) = self.g.edge_endpoints(idx)?;
306        let order = self.g.edge_weight(idx)?.order();
307
308        let payload = self.g.remove_edge(idx)?;
309
310        let siblings = self
311            .g
312            .edges_connecting(start, end)
313            .map(|edge_ref| edge_ref.id())
314            .collect::<Vec<_>>();
315
316        // update order of siblings
317        for s_idx in &siblings {
318            let sibling_order = self.g.edge_weight(*s_idx)?.order();
319            if sibling_order < order {
320                continue;
321            }
322            self.g.edge_weight_mut(*s_idx)?.set_order(sibling_order - 1);
323        }
324
325        Some(payload)
326    }
327
328    /// Returns iterator over all edges connecting start and end node.
329    #[allow(clippy::type_complexity)]
330    pub fn edges_connecting(
331        &self,
332        start: NodeIndex<Ix>,
333        end: NodeIndex<Ix>,
334    ) -> impl Iterator<Item = (EdgeIndex<Ix>, &Edge<N, E, Ty, Ix, Dn, De>)> {
335        self.g
336            .edges_connecting(start, end)
337            .map(|e| (e.id(), e.weight()))
338    }
339
340    /// Provides iterator over all nodes and their indices.
341    pub fn nodes_iter(&self) -> impl Iterator<Item = (NodeIndex<Ix>, &Node<N, E, Ty, Ix, Dn>)> {
342        self.g.node_references()
343    }
344
345    /// Provides iterator over all edges and their indices.
346    #[allow(clippy::type_complexity)]
347    pub fn edges_iter(&self) -> impl Iterator<Item = (EdgeIndex<Ix>, &Edge<N, E, Ty, Ix, Dn, De>)> {
348        self.g.edge_references().map(|e| (e.id(), e.weight()))
349    }
350
351    pub fn node(&self, i: NodeIndex<Ix>) -> Option<&Node<N, E, Ty, Ix, Dn>> {
352        self.g.node_weight(i)
353    }
354
355    pub fn edge(&self, i: EdgeIndex<Ix>) -> Option<&Edge<N, E, Ty, Ix, Dn, De>> {
356        self.g.edge_weight(i)
357    }
358
359    pub fn edge_endpoints(&self, i: EdgeIndex<Ix>) -> Option<(NodeIndex<Ix>, NodeIndex<Ix>)> {
360        self.g.edge_endpoints(i)
361    }
362
363    pub fn node_mut(&mut self, i: NodeIndex<Ix>) -> Option<&mut Node<N, E, Ty, Ix, Dn>> {
364        self.g.node_weight_mut(i)
365    }
366
367    pub fn edge_mut(&mut self, i: EdgeIndex<Ix>) -> Option<&mut Edge<N, E, Ty, Ix, Dn, De>> {
368        self.g.edge_weight_mut(i)
369    }
370
371    pub fn is_directed(&self) -> bool {
372        self.g.is_directed()
373    }
374
375    pub fn edges_num(&self, idx: NodeIndex<Ix>) -> usize {
376        self.g.edges(idx).count()
377    }
378
379    pub fn edges_directed(
380        &self,
381        idx: NodeIndex<Ix>,
382        dir: Direction,
383    ) -> impl Iterator<Item = EdgeReference<'_, Edge<N, E, Ty, Ix, Dn, De>, Ix>> {
384        self.g.edges_directed(idx, dir)
385    }
386
387    pub fn selected_nodes(&self) -> &[NodeIndex<Ix>] {
388        &self.selected_nodes
389    }
390
391    pub fn set_selected_nodes(&mut self, nodes: Vec<NodeIndex<Ix>>) {
392        self.selected_nodes = nodes;
393    }
394
395    pub fn selected_edges(&self) -> &[EdgeIndex<Ix>] {
396        &self.selected_edges
397    }
398
399    pub fn set_selected_edges(&mut self, edges: Vec<EdgeIndex<Ix>>) {
400        self.selected_edges = edges;
401    }
402
403    pub fn dragged_node(&self) -> Option<NodeIndex<Ix>> {
404        self.dragged_node
405    }
406
407    pub fn set_dragged_node(&mut self, node: Option<NodeIndex<Ix>>) {
408        self.dragged_node = node;
409    }
410
411    pub fn hovered_node(&self) -> Option<NodeIndex<Ix>> {
412        self.hovered_node
413    }
414
415    pub fn set_hovered_node(&mut self, node: Option<NodeIndex<Ix>>) {
416        self.hovered_node = node;
417    }
418
419    pub fn edge_count(&self) -> usize {
420        self.g.edge_count()
421    }
422
423    pub fn node_count(&self) -> usize {
424        self.g.node_count()
425    }
426
427    pub fn set_bounds(&mut self, bounds: Rect) {
428        self.bounds = bounds;
429    }
430
431    pub fn bounds(&self) -> Rect {
432        self.bounds
433    }
434}
435
#[cfg(test)]
mod tests {
    use super::*;
    use petgraph::stable_graph::StableGraph;

    /// Returns the order currently stored on edge `idx`; panics if the edge is missing.
    fn order_of(g: &Graph<(), (), Directed>, idx: EdgeIndex) -> usize {
        g.edge(idx).unwrap().order()
    }

    #[test]
    fn edge_orders_do_not_duplicate_in_same_direction() {
        // Two isolated nodes in a directed graph using the default display types.
        let mut base: StableGraph<(), ()> = StableGraph::default();
        let a = base.add_node(());
        let b = base.add_node(());
        let mut g: Graph<(), (), Directed> = Graph::new(base.map(
            |_, ()| crate::Node::new(()),
            |_, ()| crate::Edge::new(()),
        ));

        // A->B starts alone at 0; adding B->A creates a second 0, so both are bumped to 1.
        let ab_first = g.add_edge(a, b, ());
        let ba = g.add_edge(b, a, ());
        assert_eq!(
            order_of(&g, ab_first),
            1,
            "A->B should be bumped to order 1 when B->A exists at 0"
        );
        assert_eq!(
            order_of(&g, ba),
            1,
            "B->A should be bumped to order 1 when A->B exists at 0"
        );

        // A->B currently holds {1}, so the next A->B takes the smallest free order: 0.
        let ab_second = g.add_edge(a, b, ());
        assert_eq!(
            order_of(&g, ab_second),
            0,
            "Second A->B edge should get order 0 (smallest unused), not stack at 1"
        );

        // A->B now holds {0, 1}, so the next one gets 2.
        let ab_third = g.add_edge(a, b, ());
        assert_eq!(
            order_of(&g, ab_third),
            2,
            "Third A->B edge should get order 2"
        );
    }
}