1use serde::*;
3use std::collections::HashMap;
4
5use petgraph::prelude::*;
6pub use petgraph::prelude::NodeIndex;
10#[derive(Clone, Debug, Default, Deserialize, Serialize)]
14pub struct NxGraph<N, E>
16where
17 N: Default,
18 E: Default,
19{
20 mapping: HashMap<String, EdgeIndex>,
21 graph: StableUnGraph<N, E>,
22}
23
24fn node_pair_key(n1: NodeIndex, n2: NodeIndex) -> String {
28 let v = if n1 > n2 { [n2, n1] } else { [n1, n2] };
29 format!("{}-{}", v[0].index(), v[1].index())
30}
31
32impl<N, E> NxGraph<N, E>
33where
34 N: Default,
35 E: Default,
36{
37 fn edge_index_between(&self, n1: NodeIndex, n2: NodeIndex) -> Option<EdgeIndex> {
38 self.mapping.get(&node_pair_key(n1, n2)).map(|v| *v)
45 }
46
47 fn get_node_data(&self, n: NodeIndex) -> &N {
49 self.graph.node_weight(n).expect("no node")
50 }
51
52 fn get_node_data_mut(&mut self, n: NodeIndex) -> &mut N {
54 self.graph.node_weight_mut(n).expect("no node")
55 }
56
57 fn get_edge_data(&self, node1: NodeIndex, node2: NodeIndex) -> &E {
59 let edge_index = self.edge_index_between(node1, node2).expect("no edge index");
60 self.graph.edge_weight(edge_index).expect("no edge")
61 }
62
63 fn get_edge_data_mut(&mut self, node1: NodeIndex, node2: NodeIndex) -> &mut E {
65 let edge_index = self.edge_index_between(node1, node2).expect("no edge index");
66 self.graph.edge_weight_mut(edge_index).expect("no edge")
67 }
68}
69impl<N, E> NxGraph<N, E>
94where
95 N: Default,
96 E: Default,
97{
98 pub fn new() -> Self {
100 Self { ..Default::default() }
101 }
102
103 pub fn neighbors(&self, n: NodeIndex) -> impl Iterator<Item = NodeIndex> + '_ {
109 self.graph.neighbors(n)
110 }
111
112 pub fn node_indices(&self) -> impl Iterator<Item = NodeIndex> + '_ {
114 self.graph.node_indices()
115 }
116
117 pub fn has_node(&self, n: NodeIndex) -> bool {
119 self.graph.contains_node(n)
120 }
121
122 pub fn has_edge(&self, u: NodeIndex, v: NodeIndex) -> bool {
124 self.graph.find_edge(u, v).is_some()
125 }
126
127 pub fn number_of_nodes(&self) -> usize {
129 self.graph.node_count()
130 }
131
132 pub fn number_of_edges(&self) -> usize {
134 self.graph.edge_count()
135 }
136
137 pub fn add_node(&mut self, data: N) -> NodeIndex {
139 self.graph.add_node(data)
140 }
141
142 pub fn add_nodes_from<M: IntoIterator<Item = N>>(&mut self, nodes: M) -> Vec<NodeIndex> {
144 nodes.into_iter().map(|node| self.add_node(node)).collect()
145 }
146
147 pub fn add_edge(&mut self, u: NodeIndex, v: NodeIndex, data: E) {
155 assert_ne!(u, v, "self-loop is not allowed!");
156
157 let e = self.graph.update_edge(u, v, data);
159
160 self.mapping.insert(node_pair_key(u, v), e);
162 }
163
164 pub fn add_edges_from<M: IntoIterator<Item = (NodeIndex, NodeIndex, E)>>(&mut self, edges: M) {
166 for (u, v, d) in edges {
167 self.add_edge(u, v, d);
168 }
169 }
170
171 pub fn remove_edge(&mut self, node1: NodeIndex, node2: NodeIndex) -> Option<E> {
174 if let Some(e) = self.mapping.remove(&node_pair_key(node1, node2)) {
175 self.graph.remove_edge(e)
176 } else {
177 None
178 }
179 }
180
181 pub fn remove_node(&mut self, n: NodeIndex) -> Option<N> {
184 self.graph.remove_node(n)
185 }
186
187 pub fn clear(&mut self) {
189 self.graph.clear();
190 }
191
192 pub fn clear_edges(&mut self) {
194 self.graph.clear_edges()
195 }
196}
197impl<N, E> NxGraph<N, E>
201where
202 N: Default,
203 E: Default,
204{
205 pub fn raw_graph(&self) -> &StableUnGraph<N, E> {
207 &self.graph
208 }
209
210 pub fn raw_graph_mut(&mut self) -> &mut StableUnGraph<N, E> {
212 &mut self.graph
213 }
214}
#[cfg(feature = "adhoc")]
impl<N, E> NxGraph<N, E>
where
    N: Default + Clone,
    E: Default + Clone,
{
    /// Borrows the data on node `n`, or `None` if there is no such node.
    pub fn get_node(&self, n: NodeIndex) -> Option<&N> {
        self.graph.node_weight(n)
    }

    /// Borrows the data on the edge joining `u` and `v` (either order), or
    /// `None` if the pair has no edge.
    pub fn get_edge(&self, u: NodeIndex, v: NodeIndex) -> Option<&E> {
        self.edge_index_between(u, v)
            .and_then(|ei| self.graph.edge_weight(ei))
    }

    /// Mutably borrows the data on the edge joining `u` and `v` (either
    /// order), or `None` if the pair has no edge.
    pub fn get_edge_mut(&mut self, u: NodeIndex, v: NodeIndex) -> Option<&mut E> {
        match self.edge_index_between(u, v) {
            Some(ei) => self.graph.edge_weight_mut(ei),
            None => None,
        }
    }
}
244impl<N, E> NxGraph<N, E>
249where
250 N: Default + Clone,
251 E: Default + Clone,
252{
253 pub fn from_raw_graph(graph: StableUnGraph<N, E>) -> Self {
255 let edges: Vec<_> = graph
256 .edge_indices()
257 .map(|e| {
258 let (u, v) = graph.edge_endpoints(e).unwrap();
259 let edata = graph.edge_weight(e).unwrap().to_owned();
260 (u, v, edata)
261 })
262 .collect();
263
264 let mut g = Self { graph, ..Default::default() };
265 g.add_edges_from(edges);
266 g
267 }
268}
269
270impl NxGraph<usize, usize> {
271 pub fn path_graph(n: usize) -> Self {
274 let mut g = Self::new();
275 let nodes = g.add_nodes_from(1..=n);
276
277 for p in nodes.windows(2) {
278 g.add_edge(p[0], p[1], 0)
279 }
280
281 g
282 }
283}
284
#[test]
fn test_path_graph() {
    // A path over five nodes has exactly four edges.
    let path = NxGraph::path_graph(5);
    assert_eq!(path.number_of_nodes(), 5);
    assert_eq!(path.number_of_edges(), 4);
}
291impl<N, E> std::ops::Index<NodeIndex> for NxGraph<N, E>
295where
296 N: Default,
297 E: Default,
298{
299 type Output = N;
300
301 fn index(&self, n: NodeIndex) -> &Self::Output {
302 self.get_node_data(n)
303 }
304}
305
306impl<N, E> std::ops::IndexMut<NodeIndex> for NxGraph<N, E>
307where
308 N: Default,
309 E: Default,
310{
311 fn index_mut(&mut self, n: NodeIndex) -> &mut Self::Output {
312 self.get_node_data_mut(n)
313 }
314}
315impl<N, E> std::ops::Index<(NodeIndex, NodeIndex)> for NxGraph<N, E>
319where
320 N: Default,
321 E: Default,
322{
323 type Output = E;
324
325 fn index(&self, e: (NodeIndex, NodeIndex)) -> &Self::Output {
326 self.get_edge_data(e.0, e.1)
327 }
328}
329
330impl<N, E> std::ops::IndexMut<(NodeIndex, NodeIndex)> for NxGraph<N, E>
331where
332 N: Default,
333 E: Default,
334{
335 fn index_mut(&mut self, e: (NodeIndex, NodeIndex)) -> &mut Self::Output {
336 self.get_edge_data_mut(e.0, e.1)
337 }
338}
339pub struct Nodes<'a, N, E>
344where
345 N: Default,
346 E: Default,
347{
348 nodes: std::vec::IntoIter<NodeIndex>,
350
351 parent: &'a NxGraph<N, E>,
353}
354
355impl<'a, N, E> Nodes<'a, N, E>
356where
357 N: Default,
358 E: Default,
359{
360 fn new(g: &'a NxGraph<N, E>) -> Self {
361 let nodes: Vec<_> = g.graph.node_indices().collect();
362
363 Self {
364 parent: g,
365 nodes: nodes.into_iter(),
366 }
367 }
368}
369
370impl<'a, N, E> Iterator for Nodes<'a, N, E>
371where
372 N: Default,
373 E: Default,
374{
375 type Item = (NodeIndex, &'a N);
376
377 fn next(&mut self) -> Option<Self::Item> {
378 if let Some(cur) = self.nodes.next() {
379 Some((cur, &self.parent.graph[cur]))
380 } else {
381 None
382 }
383 }
384}
385
386impl<'a, N, E> std::ops::Index<NodeIndex> for Nodes<'a, N, E>
387where
388 N: Default,
389 E: Default,
390{
391 type Output = N;
392
393 fn index(&self, n: NodeIndex) -> &Self::Output {
394 &self.parent[n]
395 }
396}
397pub struct Edges<'a, N, E>
402where
403 N: Default,
404 E: Default,
405{
406 parent: &'a NxGraph<N, E>,
408
409 edges: std::vec::IntoIter<EdgeIndex>,
411}
412
413impl<'a, N, E> Edges<'a, N, E>
414where
415 N: Default,
416 E: Default,
417{
418 fn new(g: &'a NxGraph<N, E>) -> Self {
419 let edges: Vec<_> = g.graph.edge_indices().collect();
420
421 Self {
422 parent: g,
423 edges: edges.into_iter(),
424 }
425 }
426}
427
428impl<'a, N, E> Iterator for Edges<'a, N, E>
429where
430 N: Default,
431 E: Default,
432{
433 type Item = (NodeIndex, NodeIndex, &'a E);
434
435 fn next(&mut self) -> Option<Self::Item> {
437 if let Some(cur) = self.edges.next() {
438 let (u, v) = self
439 .parent
440 .graph
441 .edge_endpoints(cur)
442 .expect("no graph endpoints");
443 let edge_data = &self.parent.graph[cur];
444 Some((u, v, edge_data))
445 } else {
446 None
447 }
448 }
449}
450
451impl<'a, N, E> std::ops::Index<(NodeIndex, NodeIndex)> for Edges<'a, N, E>
452where
453 N: Default,
454 E: Default,
455{
456 type Output = E;
457
458 fn index(&self, e: (NodeIndex, NodeIndex)) -> &Self::Output {
459 &self.parent[e]
460 }
461}
462impl<N, E> NxGraph<N, E>
504where
505 N: Default,
506 E: Default,
507{
508 pub fn nodes(&self) -> Nodes<N, E> {
514 Nodes::new(self)
515 }
516
517 pub fn edges(&self) -> Edges<N, E> {
523 Edges::new(self)
524 }
525}
#[cfg(test)]
mod test {
    use super::*;

    /// Minimal edge payload used by the tests.
    #[derive(Clone, Default, Debug, PartialEq)]
    struct Edge {
        weight: f64,
    }

    impl Edge {
        fn new(weight: f64) -> Self {
            Self { weight }
        }
    }

    /// Minimal node payload used by the tests.
    #[derive(Clone, Default, Debug, PartialEq)]
    struct Node {
        position: [f64; 3],
    }

    #[test]
    fn test_graph() {
        let mut g = NxGraph::new();
        let n1 = g.add_node(Node::default());
        let n2 = g.add_node(Node::default());
        let n3 = g.add_node(Node::default());

        g.add_edge(n1, n2, Edge { weight: 1.0 });
        assert_eq!(1, g.number_of_edges());

        // Re-adding an existing pair replaces the data instead of duplicating.
        g.add_edge(n1, n2, Edge { weight: 2.0 });
        assert_eq!(1, g.number_of_edges());
        assert_eq!(g[(n1, n2)].weight, 2.0);

        g.add_edge(n1, n3, Edge::default());
        let n4 = g.add_node(Node::default());
        let _ = g.remove_node(n4);
        assert_eq!(g.number_of_nodes(), 3);
        assert_eq!(g.number_of_edges(), 2);

        let node = Node { position: [1.0; 3] };
        let n4 = g.add_node(node.clone());
        let edge = Edge { weight: 2.2 };
        g.add_edge(n1, n4, edge.clone());
        let x = g.remove_edge(n2, n4);
        assert_eq!(x, None);
        let x = g.remove_edge(n1, n4);
        assert_eq!(x, Some(edge));
        let x = g.remove_node(n4);
        assert_eq!(x, Some(node));

        assert!(g.has_node(n1));
        assert!(g.has_edge(n1, n2));
        assert!(!g.has_edge(n2, n3));
        let _ = g.remove_edge(n1, n3);
        assert_eq!(g.number_of_edges(), 1);
        assert!(!g.has_edge(n1, n3));

        g[n1].position = [1.9; 3];

        let nodes = g.nodes();
        assert_eq!(nodes[n1].position, [1.9; 3]);

        g[(n1, n2)].weight = 0.3;

        // Edge lookups are order-independent.
        let edges = g.edges();
        assert_eq!(edges[(n1, n2)].weight, 0.3);
        assert_eq!(edges[(n2, n1)].weight, 0.3);

        for (u, node_data) in g.nodes() {
            dbg!(u, node_data);
        }

        for (u, v, edge_data) in g.edges() {
            dbg!(u, v, edge_data);
        }

        for u in g.neighbors(n1) {
            dbg!(&g[u]);
        }

        g.clear();
        assert_eq!(g.number_of_nodes(), 0);
        assert_eq!(g.number_of_edges(), 0);
    }

    // Only the final self-loop insertion may panic; pinning the message keeps
    // a failure in the earlier assertions from being mistaken for success.
    #[test]
    #[should_panic(expected = "self-loop is not allowed!")]
    fn test_special_graph() {
        let mut g = NxGraph::new();
        let n1 = g.add_node(Node::default());
        let n2 = g.add_node(Node::default());

        g.add_edge(n1, n2, Edge::new(1.0));
        assert_eq!(g[(n1, n2)].weight, 1.0);
        assert_eq!(g[(n2, n1)].weight, 1.0);

        // Reversed endpoints address the same undirected edge.
        g.add_edge(n2, n1, Edge::new(2.0));
        assert_eq!(g[(n1, n2)].weight, 2.0);

        g.add_edge(n2, n2, Edge::default());
    }
}
649