// ade_graph/implementations/graph.rs

1use crate::implementations::FilteredGraph;
2use ade_traits::{EdgeTrait, GraphViewTrait, NodeTrait};
3use std::collections::HashMap;
4use std::fmt::Debug;
5
/// A directed graph storing node payloads by `u32` key and edge payloads by their
/// `(source, target)` key pair.
///
/// Each node also carries its own predecessor/successor key sets; `add_edge`, `remove_edge`
/// and `remove_node` update those sets alongside the `edges` map.
#[derive(Debug)]
pub struct Graph<N, E> {
    /// Node payloads, indexed by node key.
    nodes: HashMap<u32, N>,
    /// Edge payloads, indexed by `(source key, target key)`; at most one edge per ordered pair.
    edges: HashMap<(u32, u32), E>,
}
11
12impl<N: NodeTrait, E: EdgeTrait> Graph<N, E> {
13    pub fn new(nodes: Vec<N>, edges: Vec<E>) -> Self {
14        let mut graph = Graph {
15            nodes: HashMap::with_capacity(nodes.len()),
16            edges: HashMap::with_capacity(edges.len()),
17        };
18
19        for node in nodes {
20            graph.add_node(node);
21        }
22
23        for edge in edges {
24            graph.add_edge(edge);
25        }
26
27        graph
28    }
29
30    pub fn add_node(&mut self, node: N) -> bool {
31        self.nodes.insert(node.key(), node).is_none()
32    }
33
34    pub fn remove_node(&mut self, key: u32) {
35        // Collect edge keys to remove before mutating
36        let edges_to_remove: Vec<(u32, u32)> = if let Some(node) = self.nodes.get(&key) {
37            let mut edges = Vec::new();
38
39            // Incoming edges (predecessor -> this node)
40            for predecessor in node.predecessors() {
41                edges.push((*predecessor, key));
42            }
43
44            // Outgoing edges (this node -> successor)
45            for successor in node.successors() {
46                edges.push((key, *successor));
47            }
48
49            edges
50        } else {
51            Vec::new()
52        };
53
54        // Remove all connected edges
55        for (source, target) in edges_to_remove {
56            self.remove_edge(source, target);
57        }
58
59        // Remove the node itself
60        self.nodes.remove(&key);
61    }
62
63    pub fn add_edge(&mut self, edge: E) -> bool {
64        if !self.nodes.contains_key(&edge.source()) || !self.nodes.contains_key(&edge.target()) {
65            return false;
66        }
67
68        // Update successors and predecessors
69        if let Some(source_node) = self.nodes.get_mut(&edge.source()) {
70            source_node.add_successor(edge.target());
71        }
72        if let Some(target_node) = self.nodes.get_mut(&edge.target()) {
73            target_node.add_predecessor(edge.source());
74        }
75
76        self.edges
77            .insert((edge.source(), edge.target()), edge)
78            .is_none()
79    }
80
81    pub fn remove_edge(&mut self, source: u32, target: u32) {
82        let edge_key = (source, target);
83
84        if self.edges.remove(&edge_key).is_some() {
85            // Update node connections
86            if let Some(source_node) = self.nodes.get_mut(&source) {
87                source_node.remove_successor(target);
88            }
89            if let Some(target_node) = self.nodes.get_mut(&target) {
90                target_node.remove_predecessor(source);
91            }
92        }
93    }
94
95    // pub fn get_subgraph_graph(&self, node_keys: &[u32]) -> Graph<N, E> {
96    //     let mut subgraph = Graph::<N, E>::new(Vec::new(), Vec::new());
97    //     let node_key_set: HashSet<u32> = node_keys.iter().copied().collect();
98
99    //     // Add nodes that are in the key set
100    //     for node in self.get_nodes() {
101    //         if node_key_set.contains(&node.key()) {
102    //             subgraph.add_node(node.fresh_copy());
103    //         }
104    //     }
105
106    //     // Add edges where both source and target are in the key set
107    //     for edge in self.get_edges() {
108    //         let source = edge.source();
109    //         let target = edge.target();
110    //         if node_key_set.contains(&source) && node_key_set.contains(&target) {
111    //             subgraph.add_edge(edge.clone());
112    //         }
113    //     }
114
115    //     subgraph
116    // }
117}
118
119impl<N: NodeTrait, E: EdgeTrait> GraphViewTrait<N, E> for Graph<N, E> {
120    fn is_empty(&self) -> bool {
121        self.nodes.is_empty()
122    }
123
124    fn has_sequential_keys(&self) -> bool {
125        let size = self.nodes.len();
126        if size == 0 {
127            return true;
128        }
129
130        // Quick checks first
131        if !self.nodes.contains_key(&0) || !self.nodes.contains_key(&(size as u32 - 1)) {
132            return false;
133        }
134
135        // Check if all keys are sequential
136        (0..size as u32).all(|i| self.nodes.contains_key(&i))
137    }
138
139    fn get_node(&self, key: u32) -> &N {
140        self.nodes
141            .get(&key)
142            .unwrap_or_else(|| panic!("Node {} not found", key))
143    }
144
145    fn has_node(&self, key: u32) -> bool {
146        self.nodes.contains_key(&key)
147    }
148
149    fn get_edge(&self, source: u32, target: u32) -> &E {
150        let edge_key = (source, target);
151        self.edges
152            .get(&edge_key)
153            .unwrap_or_else(|| panic!("Edge {}→{} not found", source, target))
154    }
155
156    fn has_edge(&self, source: u32, target: u32) -> bool {
157        let edge_key = (source, target);
158        self.edges.contains_key(&edge_key)
159    }
160
161    fn get_nodes<'a>(&'a self) -> impl Iterator<Item = &'a N>
162    where
163        N: 'a,
164    {
165        self.nodes.values()
166    }
167
168    fn get_node_keys(&self) -> impl Iterator<Item = u32> {
169        self.nodes.keys().copied()
170    }
171
172    fn get_edges<'a>(&'a self) -> impl Iterator<Item = &'a E>
173    where
174        E: 'a,
175    {
176        self.edges.values()
177    }
178
179    fn get_predecessors<'a>(&'a self, node_key: u32) -> impl Iterator<Item = &'a N>
180    where
181        N: 'a,
182    {
183        self.get_node(node_key)
184            .predecessors()
185            .iter()
186            .map(|pred_key| self.get_node(*pred_key))
187    }
188
189    fn get_predecessors_keys(&self, node_key: u32) -> impl Iterator<Item = u32> {
190        self.get_node(node_key).predecessors().iter().copied()
191    }
192
193    fn get_successors<'a>(&'a self, node_key: u32) -> impl Iterator<Item = &'a N>
194    where
195        N: 'a,
196    {
197        self.get_node(node_key)
198            .successors()
199            .iter()
200            .map(|succ_key| self.get_node(*succ_key))
201    }
202
203    fn get_successors_keys(&self, node_key: u32) -> impl Iterator<Item = u32> {
204        self.get_node(node_key).successors().iter().copied()
205    }
206
207    fn filter(&self, node_keys: &[u32]) -> impl GraphViewTrait<N, E> {
208        FilteredGraph::new(self, node_keys.iter().copied())
209    }
210}
211
212use std::fmt;
213
214impl<N: NodeTrait, E: EdgeTrait> fmt::Display for Graph<N, E> {
215    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
216        writeln!(f, "Nodes:")?;
217        for node in self.get_nodes() {
218            writeln!(
219                f,
220                "- {} (pred: {:?}, succ: {:?})",
221                node.key(),
222                node.predecessors(),
223                node.successors()
224            )?;
225        }
226
227        writeln!(f, "\nEdges:")?;
228        for edge in self.get_edges() {
229            writeln!(f, "- {} -> {}", edge.source(), edge.target())?;
230        }
231
232        Ok(())
233    }
234}
235
#[cfg(test)]
mod tests {
    use super::*;
    use crate::implementations::edge::Edge;
    use crate::implementations::node::Node;

    #[test]
    fn test_is_empty() {
        let graph = Graph::<Node, Edge>::new(Vec::new(), Vec::new());
        assert!(graph.is_empty());
    }

    #[test]
    fn test_add_node() {
        let mut graph = Graph::<Node, Edge>::new(Vec::new(), Vec::new());

        assert!(graph.add_node(Node::new(1))); // Adding a new node should return true
        assert!(!graph.add_node(Node::new(1))); // Adding the same node again should return false
    }

    #[test]
    fn test_get_node() {
        let mut graph = Graph::<Node, Edge>::new(Vec::new(), Vec::new());

        graph.add_node(Node::new(1));

        let node = graph.get_node(1);
        assert_eq!(node.key(), 1);
        assert_eq!(node.predecessors().len(), 0);
        assert_eq!(node.successors().len(), 0);
    }

    #[test]
    #[should_panic(expected = "Node 2 not found")]
    fn test_get_node_panic() {
        let mut graph = Graph::<Node, Edge>::new(Vec::new(), Vec::new());
        graph.add_node(Node::new(1));
        graph.get_node(2);
    }

    #[test]
    fn test_get_nodes() {
        let mut graph = Graph::<Node, Edge>::new(Vec::new(), Vec::new());

        graph.add_node(Node::new(1));
        graph.add_node(Node::new(2));

        assert_eq!(graph.get_nodes().count(), 2);

        // assert it contains both nodes
        assert!(graph.get_nodes().any(|n| n.key() == 1));
        assert!(graph.get_nodes().any(|n| n.key() == 2));
    }

    #[test]
    fn test_get_node_keys() {
        let mut graph = Graph::<Node, Edge>::new(Vec::new(), Vec::new());

        graph.add_node(Node::new(1));
        graph.add_node(Node::new(2));

        assert_eq!(graph.get_node_keys().count(), 2);

        let node_keys = graph.get_node_keys().collect::<Vec<_>>();
        // assert it contains both keys
        assert!(node_keys.contains(&1));
        assert!(node_keys.contains(&2));
    }

    #[test]
    fn test_add_predecessor() {
        let mut graph = Graph::<Node, Edge>::new(Vec::new(), Vec::new());

        graph.add_node(Node::new(1));
        graph.add_node(Node::new(2));
        graph.add_edge(Edge::new(1, 2));

        assert!(graph.get_node(2).predecessors().contains(&1));
    }

    #[test]
    fn test_add_successor() {
        let mut graph = Graph::<Node, Edge>::new(Vec::new(), Vec::new());

        graph.add_node(Node::new(1));
        graph.add_node(Node::new(2));
        graph.add_edge(Edge::new(1, 2));

        assert!(graph.get_node(1).successors().contains(&2));
    }

    #[test]
    fn test_add_edge() {
        let mut graph = Graph::<Node, Edge>::new(Vec::new(), Vec::new());

        graph.add_node(Node::new(1));
        graph.add_node(Node::new(2));

        assert!(graph.add_edge(Edge::new(1, 2))); // Adding a new edge should return true
        assert!(!graph.add_edge(Edge::new(1, 2))); // Adding the same edge again should return false

        assert!(graph.has_edge(1, 2));

        // Check predecessors and successors
        assert!(graph.get_node(1).successors().contains(&2));
        assert!(graph.get_node(2).predecessors().contains(&1));
    }

    #[test]
    fn test_add_edge_missing_endpoint() {
        let mut graph = Graph::<Node, Edge>::new(vec![Node::new(1)], Vec::new());

        // Either endpoint missing: the edge is rejected and nothing changes.
        assert!(!graph.add_edge(Edge::new(1, 2)));
        assert!(!graph.add_edge(Edge::new(2, 1)));

        assert!(!graph.has_edge(1, 2));
        assert!(!graph.has_edge(2, 1));
        assert_eq!(graph.get_edges().count(), 0);
        assert_eq!(graph.get_node(1).successors().len(), 0);
        assert_eq!(graph.get_node(1).predecessors().len(), 0);
    }

    #[test]
    fn test_new_drops_edges_with_missing_endpoints() {
        let graph = Graph::<Node, Edge>::new(
            vec![Node::new(1), Node::new(2)],
            vec![Edge::new(1, 2), Edge::new(1, 3)],
        );

        assert_eq!(graph.get_nodes().count(), 2);
        assert_eq!(graph.get_edges().count(), 1);
        assert!(graph.has_edge(1, 2));
        assert!(!graph.has_edge(1, 3));
    }

    #[test]
    fn test_get_edge() {
        let mut graph = Graph::<Node, Edge>::new(Vec::new(), Vec::new());

        graph.add_node(Node::new(1));
        graph.add_node(Node::new(2));

        graph.add_edge(Edge::new(1, 2));

        let edge = graph.get_edge(1, 2);
        assert_eq!(edge.source(), 1);
        assert_eq!(edge.target(), 2);
    }

    #[test]
    #[should_panic(expected = "Edge 2→1 not found")]
    fn test_get_edge_panic() {
        let mut graph = Graph::<Node, Edge>::new(Vec::new(), Vec::new());
        graph.add_node(Node::new(1));
        graph.add_node(Node::new(2));
        graph.add_edge(Edge::new(1, 2));
        graph.get_edge(2, 1);
    }

    #[test]
    fn test_predecessors() {
        let mut graph = Graph::<Node, Edge>::new(Vec::new(), Vec::new());

        graph.add_node(Node::new(1));
        graph.add_node(Node::new(2));
        graph.add_node(Node::new(3));

        graph.add_edge(Edge::new(1, 2));
        graph.add_edge(Edge::new(3, 2));

        let predecessors = graph.get_node(2).predecessors();
        assert_eq!(predecessors.len(), 2);
        assert!(predecessors.contains(&1));
        assert!(predecessors.contains(&3));
    }

    #[test]
    fn test_successors() {
        let mut graph = Graph::<Node, Edge>::new(Vec::new(), Vec::new());

        graph.add_node(Node::new(1));
        graph.add_node(Node::new(2));
        graph.add_node(Node::new(3));

        graph.add_edge(Edge::new(1, 2));
        graph.add_edge(Edge::new(1, 3));

        let successors = graph.get_node(1).successors();
        assert_eq!(successors.len(), 2);
        assert!(successors.contains(&2));
        assert!(successors.contains(&3));
    }

    #[test]
    fn test_remove_edge() {
        let mut graph = Graph::<Node, Edge>::new(Vec::new(), Vec::new());

        graph.add_node(Node::new(1));
        graph.add_node(Node::new(2));

        graph.add_edge(Edge::new(1, 2));
        assert!(graph.has_edge(1, 2));

        graph.remove_edge(1, 2);
        assert!(!graph.has_edge(1, 2));

        // Check that predecessors and successors are also removed
        assert!(!graph.get_node(1).successors().contains(&2));
        assert!(!graph.get_node(2).predecessors().contains(&1));
    }

    #[test]
    fn test_remove_missing_edge_is_noop() {
        let mut graph = Graph::<Node, Edge>::new(
            vec![Node::new(1), Node::new(2)],
            vec![Edge::new(1, 2)],
        );

        // The reverse direction does not exist; removing it must not touch 1 -> 2.
        graph.remove_edge(2, 1);

        assert!(graph.has_edge(1, 2));
        assert!(graph.get_node(1).successors().contains(&2));
        assert!(graph.get_node(2).predecessors().contains(&1));
    }

    #[test]
    fn test_remove_node() {
        let mut graph = Graph::<Node, Edge>::new(Vec::new(), Vec::new());

        graph.add_node(Node::new(1));
        graph.add_node(Node::new(2));
        graph.add_node(Node::new(3));

        graph.add_edge(Edge::new(1, 2));
        assert!(graph.has_edge(1, 2));

        graph.add_edge(Edge::new(2, 3));
        assert!(graph.has_edge(2, 3));

        graph.add_edge(Edge::new(3, 1));
        assert!(graph.has_edge(3, 1));

        graph.remove_node(1);
        assert!(!graph.has_node(1));
        assert!(!graph.has_edge(1, 2));
        assert!(!graph.has_edge(3, 1));
        assert!(graph.has_edge(2, 3));

        let node_b = graph.get_node(2);
        assert!(!node_b.predecessors().contains(&1));
        assert!(node_b.successors().contains(&3));

        let node_c = graph.get_node(3);
        assert!(node_c.predecessors().contains(&2));
        assert!(!node_c.successors().contains(&1));
    }

    #[test]
    fn test_remove_missing_node_is_noop() {
        let mut graph = Graph::<Node, Edge>::new(
            vec![Node::new(1), Node::new(2)],
            vec![Edge::new(1, 2)],
        );

        graph.remove_node(42);

        assert_eq!(graph.get_nodes().count(), 2);
        assert!(graph.has_edge(1, 2));
    }

    #[test]
    fn test_remove_node_with_self_loop() {
        let mut graph = Graph::<Node, Edge>::new(
            vec![Node::new(1), Node::new(2)],
            vec![Edge::new(1, 1), Edge::new(1, 2)],
        );

        graph.remove_node(1);

        assert!(!graph.has_node(1));
        assert!(!graph.has_edge(1, 1));
        assert!(!graph.has_edge(1, 2));
        assert_eq!(graph.get_edges().count(), 0);
        assert!(!graph.get_node(2).predecessors().contains(&1));
    }

    #[test]
    fn test_get_predecessors() {
        let mut graph = Graph::<Node, Edge>::new(Vec::new(), Vec::new());

        graph.add_node(Node::new(1));
        graph.add_node(Node::new(2));
        graph.add_node(Node::new(3));

        graph.add_edge(Edge::new(1, 2));
        graph.add_edge(Edge::new(3, 2));

        assert_eq!(graph.get_predecessors(2).count(), 2);
        assert!(graph.get_predecessors(2).any(|n| n.key() == 1));
        assert!(graph.get_predecessors(2).any(|n| n.key() == 3));
    }

    #[test]
    fn test_get_successors() {
        let mut graph = Graph::<Node, Edge>::new(Vec::new(), Vec::new());

        graph.add_node(Node::new(1));
        graph.add_node(Node::new(2));
        graph.add_node(Node::new(3));

        graph.add_edge(Edge::new(1, 2));
        graph.add_edge(Edge::new(1, 3));

        assert_eq!(graph.get_successors(1).count(), 2);
        assert!(graph.get_successors(1).any(|n| n.key() == 2));
        assert!(graph.get_successors(1).any(|n| n.key() == 3));
    }

    #[test]
    fn test_get_successors_keys() {
        let mut graph = Graph::<Node, Edge>::new(Vec::new(), Vec::new());
        graph.add_node(Node::new(1));
        graph.add_node(Node::new(2));
        graph.add_node(Node::new(3));
        graph.add_edge(Edge::new(1, 2));
        graph.add_edge(Edge::new(1, 3));

        let successors_keys: Vec<u32> = graph.get_successors_keys(1).collect();
        assert_eq!(successors_keys.len(), 2);
        assert!(successors_keys.contains(&2));
        assert!(successors_keys.contains(&3));
    }

    #[test]
    fn test_filter() {
        let mut graph = Graph::<Node, Edge>::new(Vec::new(), Vec::new());

        graph.add_node(Node::new(1));
        graph.add_node(Node::new(2));
        graph.add_node(Node::new(3));

        graph.add_edge(Edge::new(1, 2));
        graph.add_edge(Edge::new(1, 3));

        // 2 nodes and 1 edge
        let subgraph1 = graph.filter(&[1, 2]);
        assert_eq!(subgraph1.get_nodes().count(), 2);
        assert!(subgraph1.has_node(1));
        assert!(subgraph1.has_node(2));
        assert!(!subgraph1.has_node(3));
        assert!(subgraph1.has_edge(1, 2));
        assert!(!subgraph1.has_edge(1, 3));

        // 1 node and 0 edges
        let subgraph2 = graph.filter(&[1]);
        assert_eq!(subgraph2.get_nodes().count(), 1);
        assert!(subgraph2.has_node(1));
        assert!(!subgraph2.has_node(2));
        assert!(!subgraph2.has_node(3));
        assert!(!subgraph2.has_edge(1, 2));
        assert!(!subgraph2.has_edge(1, 3));

        // 0 nodes and 0 edges (node 4 doesn't exist)
        let subgraph3 = graph.filter(&[4]);
        assert_eq!(subgraph3.get_nodes().count(), 0);
        assert!(!subgraph3.has_node(1));
        assert!(!subgraph3.has_node(2));
        assert!(!subgraph3.has_node(3));
        assert!(!subgraph3.has_edge(1, 2));
        assert!(!subgraph3.has_edge(1, 3));
    }

    #[test]
    fn test_has_node() {
        let mut graph = Graph::<Node, Edge>::new(Vec::new(), Vec::new());

        graph.add_node(Node::new(1));
        assert!(graph.has_node(1));
        assert!(!graph.has_node(2));
    }

    #[test]
    fn test_has_edge() {
        let mut graph = Graph::<Node, Edge>::new(Vec::new(), Vec::new());

        graph.add_node(Node::new(1));
        graph.add_node(Node::new(2));

        assert!(!graph.has_edge(1, 2));
        graph.add_edge(Edge::new(1, 2));
        assert!(graph.has_edge(1, 2));
        assert!(!graph.has_edge(2, 1)); // Directed graph
    }

    #[test]
    fn test_get_edges() {
        let mut graph = Graph::<Node, Edge>::new(Vec::new(), Vec::new());

        graph.add_node(Node::new(1));
        graph.add_node(Node::new(2));
        graph.add_node(Node::new(3));

        graph.add_edge(Edge::new(1, 2));
        graph.add_edge(Edge::new(2, 3));

        assert_eq!(graph.get_edges().count(), 2);
        assert!(graph
            .get_edges()
            .any(|e| e.source() == 1 && e.target() == 2));
        assert!(graph
            .get_edges()
            .any(|e| e.source() == 2 && e.target() == 3));
    }

    #[test]
    fn test_has_sequential_keys() {
        let mut graph = Graph::<Node, Edge>::new(Vec::new(), Vec::new());

        graph.add_node(Node::new(0));
        graph.add_node(Node::new(1));
        graph.add_node(Node::new(2));

        assert!(graph.has_sequential_keys());
    }

    #[test]
    fn test_has_sequential_keys_empty() {
        let graph = Graph::<Node, Edge>::new(Vec::new(), Vec::new());
        assert!(graph.has_sequential_keys());
    }

    #[test]
    fn test_has_sequential_keys_non_sequential() {
        // Gap in the middle.
        let gap = Graph::<Node, Edge>::new(vec![Node::new(0), Node::new(2)], Vec::new());
        assert!(!gap.has_sequential_keys());

        // Keys do not start at 0.
        let offset = Graph::<Node, Edge>::new(vec![Node::new(1), Node::new(2)], Vec::new());
        assert!(!offset.has_sequential_keys());
    }

    #[test]
    fn test_display_lists_nodes_and_edges() {
        let graph = Graph::<Node, Edge>::new(
            vec![Node::new(1), Node::new(2)],
            vec![Edge::new(1, 2)],
        );

        let text = graph.to_string();
        assert!(text.starts_with("Nodes:\n"));
        assert!(text.contains("- 1 (pred: "));
        assert!(text.contains("- 2 (pred: "));
        assert!(text.contains("\nEdges:\n"));
        assert!(text.contains("- 1 -> 2\n"));
    }
}