Skip to main content

proof_engine/graph/
graph_core.rs

1use glam::Vec2;
2use std::collections::{HashMap, HashSet, VecDeque};
3
/// Copyable handle identifying a node inside a [`Graph`].
///
/// The wrapped `u32` is the raw id; ids order and hash like the number itself.
#[derive(Clone, Copy, Debug, Hash)]
#[derive(PartialEq, Eq, PartialOrd, Ord)]
pub struct NodeId(pub u32);
6
/// Copyable handle identifying an edge inside a [`Graph`].
///
/// The wrapped `u32` is the raw id; ids order and hash like the number itself.
#[derive(Clone, Copy, Debug, Hash)]
#[derive(PartialEq, Eq, PartialOrd, Ord)]
pub struct EdgeId(pub u32);
9
/// Whether the edges of a [`Graph`] are one-way or symmetric.
#[derive(Clone, Copy, Debug)]
#[derive(PartialEq, Eq)]
pub enum GraphKind {
    /// An edge is listed only under its `from` node.
    Directed,
    /// An edge is listed under both of its endpoints.
    Undirected,
}
15
16#[derive(Debug, Clone)]
17pub struct NodeData<N> {
18    pub id: NodeId,
19    pub data: N,
20    pub position: Vec2,
21    pub weight: f32,
22}
23
24#[derive(Debug, Clone)]
25pub struct EdgeData<E> {
26    pub id: EdgeId,
27    pub from: NodeId,
28    pub to: NodeId,
29    pub data: E,
30    pub weight: f32,
31}
32
33#[derive(Debug, Clone)]
34pub struct Graph<N, E> {
35    pub kind: GraphKind,
36    nodes: HashMap<NodeId, NodeData<N>>,
37    edges: HashMap<EdgeId, EdgeData<E>>,
38    adjacency: HashMap<NodeId, Vec<(NodeId, EdgeId)>>,
39    next_node_id: u32,
40    next_edge_id: u32,
41}
42
43impl<N, E> Graph<N, E> {
44    pub fn new(kind: GraphKind) -> Self {
45        Self {
46            kind,
47            nodes: HashMap::new(),
48            edges: HashMap::new(),
49            adjacency: HashMap::new(),
50            next_node_id: 0,
51            next_edge_id: 0,
52        }
53    }
54
55    pub fn add_node(&mut self, data: N) -> NodeId {
56        self.add_node_with_pos(data, Vec2::ZERO)
57    }
58
59    pub fn add_node_with_pos(&mut self, data: N, position: Vec2) -> NodeId {
60        let id = NodeId(self.next_node_id);
61        self.next_node_id += 1;
62        self.nodes.insert(id, NodeData { id, data, position, weight: 1.0 });
63        self.adjacency.insert(id, Vec::new());
64        id
65    }
66
67    pub fn add_edge(&mut self, from: NodeId, to: NodeId, data: E) -> EdgeId {
68        self.add_edge_weighted(from, to, data, 1.0)
69    }
70
71    pub fn add_edge_weighted(&mut self, from: NodeId, to: NodeId, data: E, weight: f32) -> EdgeId {
72        let id = EdgeId(self.next_edge_id);
73        self.next_edge_id += 1;
74        self.edges.insert(id, EdgeData { id, from, to, data, weight });
75        if let Some(adj) = self.adjacency.get_mut(&from) {
76            adj.push((to, id));
77        }
78        if self.kind == GraphKind::Undirected {
79            if let Some(adj) = self.adjacency.get_mut(&to) {
80                adj.push((from, id));
81            }
82        }
83        id
84    }
85
86    pub fn remove_node(&mut self, id: NodeId) {
87        self.nodes.remove(&id);
88        self.adjacency.remove(&id);
89        // Remove all edges referencing this node
90        let edge_ids: Vec<EdgeId> = self.edges.iter()
91            .filter(|(_, e)| e.from == id || e.to == id)
92            .map(|(eid, _)| *eid)
93            .collect();
94        for eid in &edge_ids {
95            self.edges.remove(eid);
96        }
97        // Clean adjacency lists
98        for (_, adj) in self.adjacency.iter_mut() {
99            adj.retain(|(nid, eid)| *nid != id && !edge_ids.contains(eid));
100        }
101    }
102
103    pub fn remove_edge(&mut self, id: EdgeId) {
104        if let Some(edge) = self.edges.remove(&id) {
105            if let Some(adj) = self.adjacency.get_mut(&edge.from) {
106                adj.retain(|(_, eid)| *eid != id);
107            }
108            if self.kind == GraphKind::Undirected {
109                if let Some(adj) = self.adjacency.get_mut(&edge.to) {
110                    adj.retain(|(_, eid)| *eid != id);
111                }
112            }
113        }
114    }
115
116    pub fn neighbors(&self, id: NodeId) -> Vec<NodeId> {
117        self.adjacency.get(&id)
118            .map(|adj| adj.iter().map(|(nid, _)| *nid).collect())
119            .unwrap_or_default()
120    }
121
122    pub fn neighbor_edges(&self, id: NodeId) -> Vec<(NodeId, EdgeId)> {
123        self.adjacency.get(&id).cloned().unwrap_or_default()
124    }
125
126    pub fn degree(&self, id: NodeId) -> usize {
127        self.adjacency.get(&id).map(|a| a.len()).unwrap_or(0)
128    }
129
130    pub fn node_count(&self) -> usize {
131        self.nodes.len()
132    }
133
134    pub fn edge_count(&self) -> usize {
135        self.edges.len()
136    }
137
138    pub fn has_node(&self, id: NodeId) -> bool {
139        self.nodes.contains_key(&id)
140    }
141
142    pub fn has_edge(&self, id: EdgeId) -> bool {
143        self.edges.contains_key(&id)
144    }
145
146    pub fn get_node(&self, id: NodeId) -> Option<&NodeData<N>> {
147        self.nodes.get(&id)
148    }
149
150    pub fn get_node_mut(&mut self, id: NodeId) -> Option<&mut NodeData<N>> {
151        self.nodes.get_mut(&id)
152    }
153
154    pub fn get_edge(&self, id: EdgeId) -> Option<&EdgeData<E>> {
155        self.edges.get(&id)
156    }
157
158    pub fn get_edge_mut(&mut self, id: EdgeId) -> Option<&mut EdgeData<E>> {
159        self.edges.get_mut(&id)
160    }
161
162    pub fn find_edge(&self, from: NodeId, to: NodeId) -> Option<EdgeId> {
163        self.adjacency.get(&from)
164            .and_then(|adj| adj.iter().find(|(nid, _)| *nid == to).map(|(_, eid)| *eid))
165    }
166
167    pub fn edge_weight(&self, id: EdgeId) -> f32 {
168        self.edges.get(&id).map(|e| e.weight).unwrap_or(f32::INFINITY)
169    }
170
171    pub fn set_node_position(&mut self, id: NodeId, pos: Vec2) {
172        if let Some(node) = self.nodes.get_mut(&id) {
173            node.position = pos;
174        }
175    }
176
177    pub fn node_position(&self, id: NodeId) -> Vec2 {
178        self.nodes.get(&id).map(|n| n.position).unwrap_or(Vec2::ZERO)
179    }
180
181    // Iterators
182    pub fn nodes(&self) -> impl Iterator<Item = &NodeData<N>> {
183        self.nodes.values()
184    }
185
186    pub fn node_ids(&self) -> Vec<NodeId> {
187        let mut ids: Vec<NodeId> = self.nodes.keys().copied().collect();
188        ids.sort();
189        ids
190    }
191
192    pub fn edges(&self) -> impl Iterator<Item = &EdgeData<E>> {
193        self.edges.values()
194    }
195
196    pub fn edge_ids(&self) -> Vec<EdgeId> {
197        let mut ids: Vec<EdgeId> = self.edges.keys().copied().collect();
198        ids.sort();
199        ids
200    }
201
202    pub fn bfs(&self, start: NodeId) -> BfsIterator<N, E> {
203        let mut queue = VecDeque::new();
204        let mut visited = HashSet::new();
205        if self.has_node(start) {
206            queue.push_back(start);
207            visited.insert(start);
208        }
209        BfsIterator { graph: self, queue, visited }
210    }
211
212    pub fn dfs(&self, start: NodeId) -> DfsIterator<N, E> {
213        let mut stack = Vec::new();
214        let mut visited = HashSet::new();
215        if self.has_node(start) {
216            stack.push(start);
217            visited.insert(start);
218        }
219        DfsIterator { graph: self, stack, visited }
220    }
221
222    /// Extract a subgraph containing only the given node IDs
223    pub fn subgraph(&self, node_ids: &[NodeId]) -> Graph<N, E>
224    where
225        N: Clone,
226        E: Clone,
227    {
228        let set: HashSet<NodeId> = node_ids.iter().copied().collect();
229        let mut sub = Graph::new(self.kind);
230        // We need to preserve node IDs, so we manipulate internals
231        for &nid in node_ids {
232            if let Some(nd) = self.nodes.get(&nid) {
233                sub.nodes.insert(nid, NodeData {
234                    id: nid,
235                    data: nd.data.clone(),
236                    position: nd.position,
237                    weight: nd.weight,
238                });
239                sub.adjacency.insert(nid, Vec::new());
240            }
241        }
242        sub.next_node_id = self.next_node_id;
243        sub.next_edge_id = self.next_edge_id;
244        for (eid, ed) in &self.edges {
245            if set.contains(&ed.from) && set.contains(&ed.to) {
246                sub.edges.insert(*eid, EdgeData {
247                    id: *eid,
248                    from: ed.from,
249                    to: ed.to,
250                    data: ed.data.clone(),
251                    weight: ed.weight,
252                });
253                if let Some(adj) = sub.adjacency.get_mut(&ed.from) {
254                    adj.push((ed.to, *eid));
255                }
256                if self.kind == GraphKind::Undirected {
257                    if let Some(adj) = sub.adjacency.get_mut(&ed.to) {
258                        adj.push((ed.from, *eid));
259                    }
260                }
261            }
262        }
263        sub
264    }
265
266    /// Union of two graphs (combines all nodes and edges)
267    pub fn union(&self, other: &Graph<N, E>) -> Graph<N, E>
268    where
269        N: Clone,
270        E: Clone,
271    {
272        let mut result = self.clone();
273        let node_offset = result.next_node_id;
274        let edge_offset = result.next_edge_id;
275        let mut node_map: HashMap<NodeId, NodeId> = HashMap::new();
276        for nd in other.nodes.values() {
277            let new_id = NodeId(nd.id.0 + node_offset);
278            node_map.insert(nd.id, new_id);
279            result.nodes.insert(new_id, NodeData {
280                id: new_id,
281                data: nd.data.clone(),
282                position: nd.position,
283                weight: nd.weight,
284            });
285            result.adjacency.insert(new_id, Vec::new());
286        }
287        result.next_node_id = node_offset + other.next_node_id;
288        for ed in other.edges.values() {
289            let new_eid = EdgeId(ed.id.0 + edge_offset);
290            let new_from = node_map[&ed.from];
291            let new_to = node_map[&ed.to];
292            result.edges.insert(new_eid, EdgeData {
293                id: new_eid,
294                from: new_from,
295                to: new_to,
296                data: ed.data.clone(),
297                weight: ed.weight,
298            });
299            if let Some(adj) = result.adjacency.get_mut(&new_from) {
300                adj.push((new_to, new_eid));
301            }
302            if result.kind == GraphKind::Undirected {
303                if let Some(adj) = result.adjacency.get_mut(&new_to) {
304                    adj.push((new_from, new_eid));
305                }
306            }
307        }
308        result.next_edge_id = edge_offset + other.next_edge_id;
309        result
310    }
311
312    /// Complement graph: has edges where original doesn't, and vice versa
313    pub fn complement(&self) -> Graph<N, E>
314    where
315        N: Clone,
316        E: Default + Clone,
317    {
318        let mut result = Graph::new(self.kind);
319        for nd in self.nodes.values() {
320            result.nodes.insert(nd.id, NodeData {
321                id: nd.id,
322                data: nd.data.clone(),
323                position: nd.position,
324                weight: nd.weight,
325            });
326            result.adjacency.insert(nd.id, Vec::new());
327        }
328        result.next_node_id = self.next_node_id;
329
330        let node_ids = self.node_ids();
331        for i in 0..node_ids.len() {
332            for j in (i + 1)..node_ids.len() {
333                let a = node_ids[i];
334                let b = node_ids[j];
335                let has_edge = self.find_edge(a, b).is_some()
336                    || (self.kind == GraphKind::Undirected && self.find_edge(b, a).is_some());
337                if !has_edge {
338                    let eid = EdgeId(result.next_edge_id);
339                    result.next_edge_id += 1;
340                    result.edges.insert(eid, EdgeData {
341                        id: eid, from: a, to: b, data: E::default(), weight: 1.0,
342                    });
343                    if let Some(adj) = result.adjacency.get_mut(&a) {
344                        adj.push((b, eid));
345                    }
346                    if result.kind == GraphKind::Undirected {
347                        if let Some(adj) = result.adjacency.get_mut(&b) {
348                            adj.push((a, eid));
349                        }
350                    }
351                }
352            }
353        }
354        result
355    }
356}
357
358impl<N: Clone, E: Clone> Graph<N, E> {
359    pub fn to_adjacency_matrix(&self) -> AdjacencyMatrix {
360        let node_ids = self.node_ids();
361        let n = node_ids.len();
362        let mut index_map: HashMap<NodeId, usize> = HashMap::new();
363        for (i, &nid) in node_ids.iter().enumerate() {
364            index_map.insert(nid, i);
365        }
366        let mut matrix = vec![vec![0.0f32; n]; n];
367        for ed in self.edges.values() {
368            if let (Some(&i), Some(&j)) = (index_map.get(&ed.from), index_map.get(&ed.to)) {
369                matrix[i][j] = ed.weight;
370                if self.kind == GraphKind::Undirected {
371                    matrix[j][i] = ed.weight;
372                }
373            }
374        }
375        AdjacencyMatrix { matrix, node_ids, index_map }
376    }
377
378    pub fn to_edge_list(&self) -> EdgeList {
379        let edges: Vec<(NodeId, NodeId, f32)> = self.edges.values()
380            .map(|e| (e.from, e.to, e.weight))
381            .collect();
382        let node_ids = self.node_ids();
383        EdgeList { edges, node_ids }
384    }
385}
386
387// BFS iterator
388pub struct BfsIterator<'a, N, E> {
389    graph: &'a Graph<N, E>,
390    queue: VecDeque<NodeId>,
391    visited: HashSet<NodeId>,
392}
393
394impl<'a, N, E> Iterator for BfsIterator<'a, N, E> {
395    type Item = NodeId;
396    fn next(&mut self) -> Option<NodeId> {
397        let current = self.queue.pop_front()?;
398        for &(neighbor, _) in self.graph.adjacency.get(&current).unwrap_or(&Vec::new()) {
399            if self.visited.insert(neighbor) {
400                self.queue.push_back(neighbor);
401            }
402        }
403        Some(current)
404    }
405}
406
407// DFS iterator
408pub struct DfsIterator<'a, N, E> {
409    graph: &'a Graph<N, E>,
410    stack: Vec<NodeId>,
411    visited: HashSet<NodeId>,
412}
413
414impl<'a, N, E> Iterator for DfsIterator<'a, N, E> {
415    type Item = NodeId;
416    fn next(&mut self) -> Option<NodeId> {
417        let current = self.stack.pop()?;
418        for &(neighbor, _) in self.graph.adjacency.get(&current).unwrap_or(&Vec::new()) {
419            if self.visited.insert(neighbor) {
420                self.stack.push(neighbor);
421            }
422        }
423        Some(current)
424    }
425}
426
427// Adjacency matrix representation
428#[derive(Debug, Clone)]
429pub struct AdjacencyMatrix {
430    pub matrix: Vec<Vec<f32>>,
431    pub node_ids: Vec<NodeId>,
432    pub index_map: HashMap<NodeId, usize>,
433}
434
435impl AdjacencyMatrix {
436    pub fn to_graph(&self, kind: GraphKind) -> Graph<(), ()> {
437        let mut g = Graph::new(kind);
438        let mut id_map: HashMap<usize, NodeId> = HashMap::new();
439        for (i, &orig_id) in self.node_ids.iter().enumerate() {
440            let nid = g.add_node(());
441            id_map.insert(i, nid);
442        }
443        let n = self.matrix.len();
444        for i in 0..n {
445            let start_j = if kind == GraphKind::Undirected { i + 1 } else { 0 };
446            for j in start_j..n {
447                if self.matrix[i][j] != 0.0 {
448                    g.add_edge_weighted(id_map[&i], id_map[&j], (), self.matrix[i][j]);
449                }
450            }
451        }
452        g
453    }
454
455    pub fn get(&self, from: NodeId, to: NodeId) -> f32 {
456        let i = self.index_map.get(&from).copied().unwrap_or(0);
457        let j = self.index_map.get(&to).copied().unwrap_or(0);
458        self.matrix[i][j]
459    }
460}
461
462// Edge list representation
463#[derive(Debug, Clone)]
464pub struct EdgeList {
465    pub edges: Vec<(NodeId, NodeId, f32)>,
466    pub node_ids: Vec<NodeId>,
467}
468
469impl EdgeList {
470    pub fn to_graph(&self, kind: GraphKind) -> Graph<(), ()> {
471        let mut g = Graph::new(kind);
472        let mut id_map: HashMap<NodeId, NodeId> = HashMap::new();
473        for &orig_id in &self.node_ids {
474            let nid = g.add_node(());
475            id_map.insert(orig_id, nid);
476        }
477        for &(from, to, w) in &self.edges {
478            if let (Some(&f), Some(&t)) = (id_map.get(&from), id_map.get(&to)) {
479                g.add_edge_weighted(f, t, (), w);
480            }
481        }
482        g
483    }
484}
485
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_add_remove_nodes() {
        let mut g: Graph<&str, ()> = Graph::new(GraphKind::Undirected);
        let a = g.add_node("A");
        let b = g.add_node("B");
        let c = g.add_node("C");
        assert_eq!(g.node_count(), 3);
        g.remove_node(b);
        assert_eq!(g.node_count(), 2);
        assert!(!g.has_node(b));
        // Survivors keep their ids and payloads.
        assert_eq!(g.get_node(a).map(|n| n.data), Some("A"));
        assert_eq!(g.get_node(c).map(|n| n.data), Some("C"));
    }

    #[test]
    fn test_add_remove_edges() {
        let mut g: Graph<(), ()> = Graph::new(GraphKind::Undirected);
        let a = g.add_node(());
        let b = g.add_node(());
        let e = g.add_edge(a, b, ());
        assert_eq!(g.edge_count(), 1);
        assert_eq!(g.degree(a), 1);
        assert_eq!(g.degree(b), 1);
        g.remove_edge(e);
        assert_eq!(g.edge_count(), 0);
        assert_eq!(g.degree(a), 0);
        assert_eq!(g.degree(b), 0);
    }

    #[test]
    fn test_neighbors() {
        let mut g: Graph<(), ()> = Graph::new(GraphKind::Undirected);
        let a = g.add_node(());
        let b = g.add_node(());
        let c = g.add_node(());
        g.add_edge(a, b, ());
        g.add_edge(a, c, ());
        let mut nbrs = g.neighbors(a);
        nbrs.sort();
        assert_eq!(nbrs, vec![b, c]);
    }

    #[test]
    fn test_directed() {
        let mut g: Graph<(), ()> = Graph::new(GraphKind::Directed);
        let a = g.add_node(());
        let b = g.add_node(());
        g.add_edge(a, b, ());
        assert_eq!(g.neighbors(a), vec![b]);
        assert!(g.neighbors(b).is_empty());
        assert_eq!(g.find_edge(b, a), None);
    }

    #[test]
    fn test_bfs() {
        let mut g: Graph<(), ()> = Graph::new(GraphKind::Undirected);
        let a = g.add_node(());
        let b = g.add_node(());
        let c = g.add_node(());
        let d = g.add_node(());
        g.add_edge(a, b, ());
        g.add_edge(b, c, ());
        g.add_edge(c, d, ());
        let bfs: Vec<NodeId> = g.bfs(a).collect();
        // On a path graph the breadth-first order is fully determined.
        assert_eq!(bfs, vec![a, b, c, d]);
    }

    #[test]
    fn test_traversal_from_unknown_start_is_empty() {
        let g: Graph<(), ()> = Graph::new(GraphKind::Undirected);
        assert_eq!(g.bfs(NodeId(0)).count(), 0);
        assert_eq!(g.dfs(NodeId(0)).count(), 0);
    }

    #[test]
    fn test_dfs() {
        let mut g: Graph<(), ()> = Graph::new(GraphKind::Undirected);
        let a = g.add_node(());
        let b = g.add_node(());
        let c = g.add_node(());
        g.add_edge(a, b, ());
        g.add_edge(a, c, ());
        let dfs: Vec<NodeId> = g.dfs(a).collect();
        assert_eq!(dfs.len(), 3);
        assert_eq!(dfs[0], a);
        assert!(dfs.contains(&b) && dfs.contains(&c));
    }

    #[test]
    fn test_adjacency_matrix_roundtrip() {
        let mut g: Graph<(), ()> = Graph::new(GraphKind::Undirected);
        let a = g.add_node(());
        let b = g.add_node(());
        let c = g.add_node(());
        g.add_edge_weighted(a, b, (), 2.0);
        g.add_edge_weighted(b, c, (), 3.0);
        let mat = g.to_adjacency_matrix();
        assert_eq!(mat.matrix.len(), 3);
        assert_eq!(mat.get(a, b), 2.0);
        assert_eq!(mat.get(c, b), 3.0);
        assert_eq!(mat.get(a, c), 0.0);
        let g2 = mat.to_graph(GraphKind::Undirected);
        assert_eq!(g2.node_count(), 3);
        assert_eq!(g2.edge_count(), 2);
    }

    #[test]
    fn test_edge_list() {
        let mut g: Graph<(), ()> = Graph::new(GraphKind::Undirected);
        let a = g.add_node(());
        let b = g.add_node(());
        g.add_edge(a, b, ());
        let el = g.to_edge_list();
        assert_eq!(el.edges.len(), 1);
        assert_eq!(el.edges[0], (a, b, 1.0));
        assert_eq!(el.node_ids, vec![a, b]);
        let g2 = el.to_graph(GraphKind::Undirected);
        assert_eq!(g2.edge_count(), 1);
    }

    #[test]
    fn test_subgraph() {
        let mut g: Graph<i32, ()> = Graph::new(GraphKind::Undirected);
        let a = g.add_node(1);
        let b = g.add_node(2);
        let c = g.add_node(3);
        g.add_edge(a, b, ());
        g.add_edge(b, c, ());
        g.add_edge(a, c, ());
        let sub = g.subgraph(&[a, b]);
        assert_eq!(sub.node_count(), 2);
        assert_eq!(sub.edge_count(), 1);
        assert!(sub.has_node(a) && sub.has_node(b) && !sub.has_node(c));
        assert_eq!(sub.get_node(b).map(|n| n.data), Some(2));
    }

    #[test]
    fn test_union() {
        let mut g1: Graph<(), ()> = Graph::new(GraphKind::Undirected);
        let a = g1.add_node(());
        let b = g1.add_node(());
        g1.add_edge(a, b, ());

        let mut g2: Graph<(), ()> = Graph::new(GraphKind::Undirected);
        let c = g2.add_node(());
        let d = g2.add_node(());
        g2.add_edge(c, d, ());

        let u = g1.union(&g2);
        assert_eq!(u.node_count(), 4);
        assert_eq!(u.edge_count(), 2);
        assert!(u.find_edge(a, b).is_some());
        // `other`'s ids are shifted past `self`'s two nodes.
        assert!(u.find_edge(NodeId(2), NodeId(3)).is_some());
    }

    #[test]
    fn test_complement() {
        let mut g: Graph<(), ()> = Graph::new(GraphKind::Undirected);
        let a = g.add_node(());
        let b = g.add_node(());
        let c = g.add_node(());
        g.add_edge(a, b, ());
        // Complete graph has 3 edges, we have 1, complement should have 2
        let comp = g.complement();
        assert_eq!(comp.edge_count(), 2);
        assert!(comp.find_edge(a, b).is_none());
        assert!(comp.find_edge(a, c).is_some());
        assert!(comp.find_edge(b, c).is_some());
    }

    #[test]
    fn test_find_edge() {
        let mut g: Graph<(), ()> = Graph::new(GraphKind::Undirected);
        let a = g.add_node(());
        let b = g.add_node(());
        let e = g.add_edge(a, b, ());
        assert_eq!(g.find_edge(a, b), Some(e));
        assert_eq!(g.find_edge(b, a), Some(e));
    }

    #[test]
    fn test_node_position() {
        let mut g: Graph<(), ()> = Graph::new(GraphKind::Undirected);
        let a = g.add_node_with_pos((), Vec2::new(3.0, 4.0));
        assert_eq!(g.node_position(a), Vec2::new(3.0, 4.0));
        g.set_node_position(a, Vec2::new(1.0, 2.0));
        assert_eq!(g.node_position(a), Vec2::new(1.0, 2.0));
    }

    #[test]
    fn test_weighted_edge() {
        let mut g: Graph<(), ()> = Graph::new(GraphKind::Directed);
        let a = g.add_node(());
        let b = g.add_node(());
        let e = g.add_edge_weighted(a, b, (), 5.0);
        assert_eq!(g.edge_weight(e), 5.0);
    }

    #[test]
    fn test_remove_node_removes_edges() {
        let mut g: Graph<(), ()> = Graph::new(GraphKind::Undirected);
        let a = g.add_node(());
        let b = g.add_node(());
        let c = g.add_node(());
        g.add_edge(a, b, ());
        g.add_edge(b, c, ());
        g.remove_node(b);
        assert_eq!(g.edge_count(), 0);
        assert_eq!(g.degree(a), 0);
        assert_eq!(g.degree(c), 0);
    }
}