gchemol_graph/
graph.rs

1// [[file:../nxgraph.note::*imports][imports:1]]
2use serde::*;
3use std::collections::HashMap;
4
5use petgraph::prelude::*;
6// imports:1 ends here
7
8// [[file:../nxgraph.note::*exports][exports:1]]
9pub use petgraph::prelude::NodeIndex;
10// exports:1 ends here
11
12// [[file:../nxgraph.note::dfa4e9ef][dfa4e9ef]]
13#[derive(Clone, Debug, Default, Deserialize, Serialize)]
14/// networkx-like API wrapper around petgraph
15pub struct NxGraph<N, E>
16where
17    N: Default,
18    E: Default,
19{
20    mapping: HashMap<String, EdgeIndex>,
21    graph: StableUnGraph<N, E>,
22}
23
24/// return sorted node pair as mapping key.
25// NOTE: if we return a tuple or array, we will encounter not string
26// key error for serde_json
27fn node_pair_key(n1: NodeIndex, n2: NodeIndex) -> String {
28    let v = if n1 > n2 { [n2, n1] } else { [n1, n2] };
29    format!("{}-{}", v[0].index(), v[1].index())
30}
31
32impl<N, E> NxGraph<N, E>
33where
34    N: Default,
35    E: Default,
36{
37    fn edge_index_between(&self, n1: NodeIndex, n2: NodeIndex) -> Option<EdgeIndex> {
38        // this is slow
39        // self.graph.find_edge(n1, n2)
40
41        // make sure n1 is always smaller than n2.
42        // let (n1, n2) = if n1 > n2 { (n2, n1) } else { (n1, n2) };
43
44        self.mapping.get(&node_pair_key(n1, n2)).map(|v| *v)
45    }
46
47    /// Return data associated with node `n`.
48    fn get_node_data(&self, n: NodeIndex) -> &N {
49        self.graph.node_weight(n).expect("no node")
50    }
51
52    /// Return a mutable reference of data associated with node `n`.
53    fn get_node_data_mut(&mut self, n: NodeIndex) -> &mut N {
54        self.graph.node_weight_mut(n).expect("no node")
55    }
56
57    /// Return data associated with edge `node1--node2`.
58    fn get_edge_data(&self, node1: NodeIndex, node2: NodeIndex) -> &E {
59        let edge_index = self.edge_index_between(node1, node2).expect("no edge index");
60        self.graph.edge_weight(edge_index).expect("no edge")
61    }
62
63    /// Return a mutable reference of data associated with edge `node1--node2`.
64    fn get_edge_data_mut(&mut self, node1: NodeIndex, node2: NodeIndex) -> &mut E {
65        let edge_index = self.edge_index_between(node1, node2).expect("no edge index");
66        self.graph.edge_weight_mut(edge_index).expect("no edge")
67    }
68}
69// dfa4e9ef ends here
70
71// [[file:../nxgraph.note::*base][base:1]]
72/// Build/Read/Edit Graph
73///
74/// # Example
75///
76/// ```
77/// use gchemol_graph::NxGraph;
78/// 
79/// let mut g = NxGraph::path_graph(2);
80/// let u = g.add_node(2);
81/// let v = g.add_node(3);
82/// g.add_edge(u, v, 5);
83/// 
84/// assert!(g.has_node(u));
85/// assert!(g.has_edge(u, v));
86/// 
87/// // loop over neighbors of node u
88/// for x in g.neighbors(u) {
89///     dbg!(x);
90/// }
91/// ```
92///
93impl<N, E> NxGraph<N, E>
94where
95    N: Default,
96    E: Default,
97{
98    /// Build a default Graph
99    pub fn new() -> Self {
100        Self { ..Default::default() }
101    }
102
103    /// Returns an iterator over all neighbors of node `n`.
104    ///
105    /// # Reference
106    ///
107    /// * https://networkx.github.io/documentation/stable/reference/classes/generated/networkx.Graph.neighbors.html
108    pub fn neighbors(&self, n: NodeIndex) -> impl Iterator<Item = NodeIndex> + '_ {
109        self.graph.neighbors(n)
110    }
111
112    /// Return an iterator over the node indices of the graph
113    pub fn node_indices(&self) -> impl Iterator<Item = NodeIndex> + '_ {
114        self.graph.node_indices()
115    }
116
117    /// Returns true if the graph contains the node n.
118    pub fn has_node(&self, n: NodeIndex) -> bool {
119        self.graph.contains_node(n)
120    }
121
122    /// Returns true if the edge (u, v) is in the graph.
123    pub fn has_edge(&self, u: NodeIndex, v: NodeIndex) -> bool {
124        self.graph.find_edge(u, v).is_some()
125    }
126
127    /// Returns the number of nodes in the graph.
128    pub fn number_of_nodes(&self) -> usize {
129        self.graph.node_count()
130    }
131
132    /// Returns the number of edges in the graph.
133    pub fn number_of_edges(&self) -> usize {
134        self.graph.edge_count()
135    }
136
137    /// Add a node with associated data into graph.
138    pub fn add_node(&mut self, data: N) -> NodeIndex {
139        self.graph.add_node(data)
140    }
141
142    /// Add multiple nodes.
143    pub fn add_nodes_from<M: IntoIterator<Item = N>>(&mut self, nodes: M) -> Vec<NodeIndex> {
144        nodes.into_iter().map(|node| self.add_node(node)).collect()
145    }
146
147    /// Add an edge with `data` between `u` and `v` (no parallel edge). If edge
148    /// u--v already exists, the associated data will be updated.
149    ///
150    /// # Panics
151    ///
152    /// * To avoid self-loop, this method will panic if node `u` and `v` are the
153    /// same.
154    pub fn add_edge(&mut self, u: NodeIndex, v: NodeIndex, data: E) {
155        assert_ne!(u, v, "self-loop is not allowed!");
156
157        // not add_edge for avoidding parallel edges
158        let e = self.graph.update_edge(u, v, data);
159
160        // update node pair to edge index mapping.
161        self.mapping.insert(node_pair_key(u, v), e);
162    }
163
164    /// Add multiple edges from `edges`.
165    pub fn add_edges_from<M: IntoIterator<Item = (NodeIndex, NodeIndex, E)>>(&mut self, edges: M) {
166        for (u, v, d) in edges {
167            self.add_edge(u, v, d);
168        }
169    }
170
171    /// Remove an edge between `node1` and `node2`. Return None if trying to
172    /// remove a non-existent edge.
173    pub fn remove_edge(&mut self, node1: NodeIndex, node2: NodeIndex) -> Option<E> {
174        if let Some(e) = self.mapping.remove(&node_pair_key(node1, node2)) {
175            self.graph.remove_edge(e)
176        } else {
177            None
178        }
179    }
180
181    /// Removes the node `n` and all adjacent edges. Return None if trying to
182    /// remove a non-existent node.
183    pub fn remove_node(&mut self, n: NodeIndex) -> Option<N> {
184        self.graph.remove_node(n)
185    }
186
187    /// Remove all nodes and edges
188    pub fn clear(&mut self) {
189        self.graph.clear();
190    }
191
192    /// Remove all edges
193    pub fn clear_edges(&mut self) {
194        self.graph.clear_edges()
195    }
196}
197// base:1 ends here
198
199// [[file:../nxgraph.note::*extra][extra:1]]
200impl<N, E> NxGraph<N, E>
201where
202    N: Default,
203    E: Default,
204{
205    /// Provides read access to raw Graph struct.
206    pub fn raw_graph(&self) -> &StableUnGraph<N, E> {
207        &self.graph
208    }
209
210    /// Provides mut access to raw Graph struct.
211    pub fn raw_graph_mut(&mut self) -> &mut StableUnGraph<N, E> {
212        &mut self.graph
213    }
214}
215// extra:1 ends here
216
217// [[file:../nxgraph.note::a03268f7][a03268f7]]
#[cfg(feature = "adhoc")]
// Only `Default` is needed: nothing below clones node or edge data.
impl<N, E> NxGraph<N, E>
where
    N: Default,
    E: Default,
{
    /// Return the `Node` associated with node index `n`. Return None if no such
    /// node `n`.
    pub fn get_node(&self, n: NodeIndex) -> Option<&N> {
        self.graph.node_weight(n)
    }

    /// Return the associated edge data between node `u` and `v` (in either
    /// order). Return None if no such edge.
    pub fn get_edge(&self, u: NodeIndex, v: NodeIndex) -> Option<&E> {
        let ei = self.edge_index_between(u, v)?;
        self.graph.edge_weight(ei)
    }

    /// Return mutable access to the associated edge data between node `u` and
    /// `v` (in either order). Return None if no such edge.
    pub fn get_edge_mut(&mut self, u: NodeIndex, v: NodeIndex) -> Option<&mut E> {
        let ei = self.edge_index_between(u, v)?;
        self.graph.edge_weight_mut(ei)
    }
}
244// a03268f7 ends here
245
246// [[file:../nxgraph.note::d04c099c][d04c099c]]
247/// Methods for creating `NxGraph` struct
248impl<N, E> NxGraph<N, E>
249where
250    N: Default + Clone,
251    E: Default + Clone,
252{
253    /// Return `NxGraph` from raw petgraph struct.
254    pub fn from_raw_graph(graph: StableUnGraph<N, E>) -> Self {
255        let edges: Vec<_> = graph
256            .edge_indices()
257            .map(|e| {
258                let (u, v) = graph.edge_endpoints(e).unwrap();
259                let edata = graph.edge_weight(e).unwrap().to_owned();
260                (u, v, edata)
261            })
262            .collect();
263
264        let mut g = Self { graph, ..Default::default() };
265        g.add_edges_from(edges);
266        g
267    }
268}
269
270impl NxGraph<usize, usize> {
271    /// Returns the Path graph `P_n` of linearly connected nodes. Node data and
272    /// edge data are usize type, mainly for test purpose.
273    pub fn path_graph(n: usize) -> Self {
274        let mut g = Self::new();
275        let nodes = g.add_nodes_from(1..=n);
276
277        for p in nodes.windows(2) {
278            g.add_edge(p[0], p[1], 0)
279        }
280
281        g
282    }
283}
284
#[test]
fn test_path_graph() {
    // P_5 has five nodes joined by four edges
    let path = NxGraph::path_graph(5);
    let counts = (path.number_of_nodes(), path.number_of_edges());
    assert_eq!(counts, (5, 4));
}
291// d04c099c ends here
292
293// [[file:../nxgraph.note::*node][node:1]]
294impl<N, E> std::ops::Index<NodeIndex> for NxGraph<N, E>
295where
296    N: Default,
297    E: Default,
298{
299    type Output = N;
300
301    fn index(&self, n: NodeIndex) -> &Self::Output {
302        self.get_node_data(n)
303    }
304}
305
306impl<N, E> std::ops::IndexMut<NodeIndex> for NxGraph<N, E>
307where
308    N: Default,
309    E: Default,
310{
311    fn index_mut(&mut self, n: NodeIndex) -> &mut Self::Output {
312        self.get_node_data_mut(n)
313    }
314}
315// node:1 ends here
316
317// [[file:../nxgraph.note::*edge][edge:1]]
318impl<N, E> std::ops::Index<(NodeIndex, NodeIndex)> for NxGraph<N, E>
319where
320    N: Default,
321    E: Default,
322{
323    type Output = E;
324
325    fn index(&self, e: (NodeIndex, NodeIndex)) -> &Self::Output {
326        self.get_edge_data(e.0, e.1)
327    }
328}
329
330impl<N, E> std::ops::IndexMut<(NodeIndex, NodeIndex)> for NxGraph<N, E>
331where
332    N: Default,
333    E: Default,
334{
335    fn index_mut(&mut self, e: (NodeIndex, NodeIndex)) -> &mut Self::Output {
336        self.get_edge_data_mut(e.0, e.1)
337    }
338}
339// edge:1 ends here
340
341// [[file:../nxgraph.note::*nodes][nodes:1]]
342/// Node view of graph, created with [nodes](struct.NxGraph.html#method.nodes) method.
343pub struct Nodes<'a, N, E>
344where
345    N: Default,
346    E: Default,
347{
348    /// An iterator over graph node indices.
349    nodes: std::vec::IntoIter<NodeIndex>,
350
351    /// Parent graph struct.
352    parent: &'a NxGraph<N, E>,
353}
354
355impl<'a, N, E> Nodes<'a, N, E>
356where
357    N: Default,
358    E: Default,
359{
360    fn new(g: &'a NxGraph<N, E>) -> Self {
361        let nodes: Vec<_> = g.graph.node_indices().collect();
362
363        Self {
364            parent: g,
365            nodes: nodes.into_iter(),
366        }
367    }
368}
369
370impl<'a, N, E> Iterator for Nodes<'a, N, E>
371where
372    N: Default,
373    E: Default,
374{
375    type Item = (NodeIndex, &'a N);
376
377    fn next(&mut self) -> Option<Self::Item> {
378        if let Some(cur) = self.nodes.next() {
379            Some((cur, &self.parent.graph[cur]))
380        } else {
381            None
382        }
383    }
384}
385
386impl<'a, N, E> std::ops::Index<NodeIndex> for Nodes<'a, N, E>
387where
388    N: Default,
389    E: Default,
390{
391    type Output = N;
392
393    fn index(&self, n: NodeIndex) -> &Self::Output {
394        &self.parent[n]
395    }
396}
397// nodes:1 ends here
398
399// [[file:../nxgraph.note::*edges][edges:1]]
400/// Edge view of graph, created with [edges](struct.NxGraph.html#method.edges) method.
401pub struct Edges<'a, N, E>
402where
403    N: Default,
404    E: Default,
405{
406    /// Parent graph struct
407    parent: &'a NxGraph<N, E>,
408
409    /// An iterator over graph edge indices
410    edges: std::vec::IntoIter<EdgeIndex>,
411}
412
413impl<'a, N, E> Edges<'a, N, E>
414where
415    N: Default,
416    E: Default,
417{
418    fn new(g: &'a NxGraph<N, E>) -> Self {
419        let edges: Vec<_> = g.graph.edge_indices().collect();
420
421        Self {
422            parent: g,
423            edges: edges.into_iter(),
424        }
425    }
426}
427
428impl<'a, N, E> Iterator for Edges<'a, N, E>
429where
430    N: Default,
431    E: Default,
432{
433    type Item = (NodeIndex, NodeIndex, &'a E);
434
435    /// Returns a tuple in (index_i, index_j, Edge) format.
436    fn next(&mut self) -> Option<Self::Item> {
437        if let Some(cur) = self.edges.next() {
438            let (u, v) = self
439                .parent
440                .graph
441                .edge_endpoints(cur)
442                .expect("no graph endpoints");
443            let edge_data = &self.parent.graph[cur];
444            Some((u, v, edge_data))
445        } else {
446            None
447        }
448    }
449}
450
451impl<'a, N, E> std::ops::Index<(NodeIndex, NodeIndex)> for Edges<'a, N, E>
452where
453    N: Default,
454    E: Default,
455{
456    type Output = E;
457
458    fn index(&self, e: (NodeIndex, NodeIndex)) -> &Self::Output {
459        &self.parent[e]
460    }
461}
462// edges:1 ends here
463
464// [[file:../nxgraph.note::*pub][pub:1]]
465/// Node view and Edge view for `NxGraph`.
466///
467/// # Example
468///
469/// ```
470/// use gchemol_graph::NxGraph;
471/// 
472/// let mut g = NxGraph::path_graph(3);
473/// let u = g.add_node(5);
474/// let v = g.add_node(2);
475/// let w = g.add_node(1);
476/// g.add_edge(u, v, 7);
477/// g.add_edge(u, w, 6);
478/// 
479/// // loop over nodes
480/// for (node_index, node_data) in g.nodes() {
481///     // do something
482/// }
483/// 
484/// // get node data of node `u`
485/// let nodes = g.nodes();
486/// let node_u = nodes[u];
487/// assert_eq!(node_u, 5);
488/// 
489/// // Collect nodes into HashMap
490/// let nodes: std::collections::HashMap<_, _> = g.nodes().collect();
491/// assert_eq!(nodes.len(), 6);
492/// 
493/// // loop over edges
494/// for (u, v, edge_data) in g.edges() {
495///     // dbg!(u, v, edge_data)
496/// }
497/// 
498/// // get edge data
499/// let edges = g.edges();
500/// let edge_uv = edges[(u, v)];
501/// assert_eq!(edge_uv, 7);
502/// ```
503impl<N, E> NxGraph<N, E>
504where
505    N: Default,
506    E: Default,
507{
508    /// A Node view of the Graph.
509    ///
510    /// # Reference
511    ///
512    /// * https://networkx.github.io/documentation/stable/reference/classes/generated/networkx.Graph.nodes.html
513    pub fn nodes(&self) -> Nodes<N, E> {
514        Nodes::new(self)
515    }
516
517    /// An Edge view of the Graph.
518    ///
519    /// # Reference
520    ///
521    /// * https://networkx.github.io/documentation/stable/reference/classes/generated/networkx.Graph.edges.html
522    pub fn edges(&self) -> Edges<N, E> {
523        Edges::new(self)
524    }
525}
526// pub:1 ends here
527
528// [[file:../nxgraph.note::*test][test:1]]
#[cfg(test)]
mod test {
    use super::*;

    #[derive(Clone, Default, Debug, PartialEq)]
    struct Edge {
        weight: f64,
    }

    impl Edge {
        fn new(weight: f64) -> Self {
            Self { weight }
        }
    }

    #[derive(Clone, Default, Debug, PartialEq)]
    struct Node {
        /// The Cartesian position of this `Node`.
        position: [f64; 3],
    }

    #[test]
    fn test_graph() {
        // add and remove nodes
        let mut g = NxGraph::new();
        let n1 = g.add_node(Node::default());
        let n2 = g.add_node(Node::default());
        let n3 = g.add_node(Node::default());

        // add edges
        g.add_edge(n1, n2, Edge { weight: 1.0 });
        assert_eq!(1, g.number_of_edges());

        // add edge n1-n2 again. Note: no parallel edge
        g.add_edge(n1, n2, Edge { weight: 2.0 });
        assert_eq!(1, g.number_of_edges());
        // edge data has been updated
        assert_eq!(g[(n1, n2)].weight, 2.0);

        g.add_edge(n1, n3, Edge::default());
        let n4 = g.add_node(Node::default());
        let _ = g.remove_node(n4);
        assert_eq!(g.number_of_nodes(), 3);
        assert_eq!(g.number_of_edges(), 2);

        // test remove node and edge
        let node = Node { position: [1.0; 3] };
        let n4 = g.add_node(node.clone());
        let edge = Edge { weight: 2.2 };
        g.add_edge(n1, n4, edge.clone());
        let x = g.remove_edge(n2, n4);
        assert_eq!(x, None);
        let x = g.remove_edge(n1, n4);
        assert_eq!(x, Some(edge));
        let x = g.remove_node(n4);
        assert_eq!(x, Some(node));

        // test graph
        assert!(g.has_node(n1));
        assert!(g.has_edge(n1, n2));
        assert!(!g.has_edge(n2, n3));
        let _ = g.remove_edge(n1, n3);
        assert_eq!(g.number_of_edges(), 1);
        assert!(!g.has_edge(n1, n3));

        // edit node attributes
        g[n1].position = [1.9; 3];

        // node view
        let nodes = g.nodes();
        assert_eq!(nodes[n1].position, [1.9; 3]);

        // edit edge attributes
        g[(n1, n2)].weight = 0.3;

        // edge view
        let edges = g.edges();
        assert_eq!(edges[(n1, n2)].weight, 0.3);
        assert_eq!(edges[(n2, n1)].weight, 0.3);

        // loop over nodes: exactly n1, n2 and n3 are left
        assert_eq!(g.nodes().count(), 3);
        assert!(g.nodes().all(|(u, _)| [n1, n2, n3].contains(&u)));

        // loop over edges: only n1--n2 is left
        let all_edges: Vec<_> = g.edges().collect();
        assert_eq!(all_edges.len(), 1);
        let (u, v, edge_data) = all_edges[0];
        assert!((u, v) == (n1, n2) || (u, v) == (n2, n1));
        assert_eq!(edge_data.weight, 0.3);

        // loop over neighbors of node `n1`: only n2 is still connected
        let neighbors: Vec<_> = g.neighbors(n1).collect();
        assert_eq!(neighbors, vec![n2]);
        assert_eq!(g[neighbors[0]], Node::default());

        // clear graph
        g.clear();
        assert_eq!(g.number_of_nodes(), 0);
        assert_eq!(g.number_of_edges(), 0);
    }

    #[test]
    // `expected` pins the panic to the self-loop check. A failing assertion
    // earlier in the test must not count as a pass.
    #[should_panic(expected = "self-loop is not allowed!")]
    fn test_special_graph() {
        let mut g = NxGraph::new();
        let n1 = g.add_node(Node::default());
        let n2 = g.add_node(Node::default());

        g.add_edge(n1, n2, Edge::new(1.0));
        assert_eq!(g[(n1, n2)].weight, 1.0);
        assert_eq!(g[(n2, n1)].weight, 1.0);

        // parallel edge is avoided
        g.add_edge(n2, n1, Edge::new(2.0));
        assert_eq!(g.number_of_edges(), 1);
        assert_eq!(g[(n1, n2)].weight, 2.0);

        // self-loop is not allowed
        g.add_edge(n2, n2, Edge::default());
    }
}
649// test:1 ends here