duskphantom_graph/
lib.rs

1// Copyright 2024 Duskphantom Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// SPDX-License-Identifier: Apache-2.0
16
17use std::collections::{HashMap, HashSet};
18
19pub use anyhow::anyhow;
20pub use anyhow::Result;
21
/// Marker trait for values that can be stored as nodes of a [`UdGraph`]:
/// they must be hashable, comparable for equality and cloneable.
///
/// It is implemented automatically for every type meeting those bounds.
pub trait GraphNode: std::hash::Hash + Eq + Sized + Clone {}

impl<N> GraphNode for N where N: std::hash::Hash + Eq + Clone {}
25
26pub trait GraphNodeFromStr: GraphNode {
27    /// Parse the node from a str
28    fn from_str(input: &str) -> Result<Self>
29    where
30        Self: Sized;
31}
32
33impl<T> GraphNodeFromStr for T
34where
35    T: std::str::FromStr + GraphNode,
36{
37    fn from_str(input: &str) -> Result<Self> {
38        input.parse().map_err(|_| anyhow!("parse error"))
39    }
40}
41
42/// Undirected graph
43pub struct UdGraph<T: GraphNode> {
44    id_alloc: u64,
45    n_id: std::collections::HashMap<T, u64>,
46    id_n: std::collections::HashMap<u64, T>,
47    edges: std::collections::HashMap<u64, std::collections::HashSet<u64>>,
48}
49
50impl<T: GraphNode> UdGraph<T> {
51    #[allow(clippy::new_without_default)]
52    pub fn new() -> Self {
53        Self {
54            id_alloc: 0,
55            n_id: std::collections::HashMap::new(),
56            id_n: std::collections::HashMap::new(),
57            edges: std::collections::HashMap::new(),
58        }
59    }
60
61    pub fn add_node(&mut self, node: T) -> u64 {
62        if let Some(id) = self.n_id.get(&node) {
63            return *id;
64        }
65        let id = self.id_alloc;
66        assert!(id < u64::MAX);
67        self.id_alloc += 1;
68        self.n_id.insert(node.clone(), id);
69        self.id_n.insert(id, node);
70
71        self.edges.entry(id).or_default();
72
73        id
74    }
75
76    /// if any node is not in the graph, it will be added to the graph.
77    /// the self to self edge adding will be ignored
78    pub fn add_edge(&mut self, from: T, to: T) {
79        let from_id = self.add_node(from);
80        let to_id = self.add_node(to);
81        self._add_edge_by_id(from_id, to_id)
82    }
83
84    pub fn add_edge_ref(&mut self, from: &T, to: &T) {
85        let from_id = self.add_node(from.clone());
86        let to_id = self.add_node(to.clone());
87        self._add_edge_by_id(from_id, to_id)
88    }
89
90    #[inline]
91    fn _add_edge_by_id(&mut self, from_id: u64, to_id: u64) {
92        if from_id == to_id {
93            return;
94        }
95        self.edges.entry(from_id).or_default().insert(to_id);
96        self.edges.entry(to_id).or_default().insert(from_id);
97    }
98
99    fn get_node(&self, id: u64) -> Option<&T> {
100        self.id_n.get(&id)
101    }
102
103    pub fn nodes(&self) -> std::collections::hash_map::Values<u64, T> {
104        self.id_n.values()
105    }
106
107    pub fn iter(&self) -> UdGraphIter<T> {
108        UdGraphIter {
109            graph: self,
110            nodes_iter: self.id_n.keys(),
111        }
112    }
113
114    pub fn is_empty(&self) -> bool {
115        self.id_n.is_empty()
116    }
117
118    /// get neighbor nodes of the node `from`
119    pub fn get_nbs<'a>(&'a self, from: &T) -> Option<Neighbors<'a, T>> {
120        let id = self.n_id.get(from)?;
121        let tos = self.edges.get(id)?;
122        let tos_iter = tos.iter();
123        Some(Neighbors {
124            graph: self,
125            tos,
126            tos_iter,
127        })
128    }
129}
130
131impl<T: GraphNode> UdGraph<T> {
132    pub fn gen_dot(&self, graph_name: &str, node_shower: impl Fn(&T) -> String) -> String {
133        let mut res = String::new();
134        res.push_str(&format!("graph {} {{\n", graph_name));
135
136        let mut showed: HashSet<(u64, u64)> = HashSet::new();
137        let mut sorted_edges: Vec<(&u64, &HashSet<u64>)> = self.edges.iter().collect();
138        sorted_edges.sort_by_key(|(k, _)| **k);
139        for (k, nbs) in sorted_edges {
140            let from = self.get_node(*k).unwrap();
141            let from_str = node_shower(from);
142            let mut sorted_nbs: Vec<&u64> = nbs.iter().collect();
143            sorted_nbs.sort_by_key(|&&x| x);
144            for to in sorted_nbs {
145                if showed.contains(&(*k, *to)) || showed.contains(&(*to, *k)) {
146                    continue;
147                }
148                showed.insert((*k, *to));
149                let to = self.get_node(*to).unwrap();
150                let to_str = node_shower(to);
151                res.push_str(&format!("{} -- {};\n", from_str, to_str));
152            }
153        }
154        res.push_str("}\n");
155        res
156    }
157}
158
159pub struct UdGraphIter<'a, T: GraphNode> {
160    graph: &'a UdGraph<T>,
161    nodes_iter: std::collections::hash_map::Keys<'a, u64, T>,
162}
163
164impl<'a, T: GraphNode> Iterator for UdGraphIter<'a, T> {
165    type Item = (&'a T, Neighbors<'a, T>);
166    fn next(&mut self) -> Option<Self::Item> {
167        self.nodes_iter.next().map(|id| {
168            let node = self.graph.get_node(*id).unwrap();
169            let nbs = self.graph.get_nbs(node).unwrap();
170            (node, nbs)
171        })
172    }
173}
174
175pub struct Neighbors<'a, T: GraphNode> {
176    graph: &'a UdGraph<T>,
177    tos: &'a std::collections::HashSet<u64>,
178    tos_iter: std::collections::hash_set::Iter<'a, u64>,
179}
180/// impl Iterator for ToNodes
181impl<'a, T: GraphNode> Iterator for Neighbors<'a, T> {
182    type Item = &'a T;
183    fn next(&mut self) -> Option<Self::Item> {
184        self.tos_iter
185            .next()
186            .map(|id| self.graph.get_node(*id).unwrap())
187    }
188}
189
190impl<T: GraphNode> Neighbors<'_, T> {
191    pub fn contains(&self, node: &T) -> bool {
192        self.graph
193            .n_id
194            .get(node)
195            .map_or(false, |id| self.tos.contains(id))
196    }
197}
198
199impl<T: GraphNode> From<HashMap<T, HashSet<T>>> for UdGraph<T> {
200    fn from(value: HashMap<T, HashSet<T>>) -> Self {
201        let mut g = UdGraph::new();
202        for (k, v) in value {
203            for to in v {
204                g.add_edge(k.clone(), to);
205            }
206        }
207        g
208    }
209}
210impl<T: GraphNode> From<HashSet<(T, T)>> for UdGraph<T> {
211    fn from(value: HashSet<(T, T)>) -> Self {
212        let mut g = UdGraph::new();
213        for (k, v) in value {
214            g.add_edge(k, v);
215        }
216        g
217    }
218}
219
220impl<T: GraphNode> From<UdGraph<T>> for HashMap<T, HashSet<T>> {
221    fn from(g: UdGraph<T>) -> Self {
222        let mut res = HashMap::new();
223        for (k, v) in g.edges.iter() {
224            let k = g.get_node(*k).unwrap().clone();
225            let mut vs = HashSet::new();
226            for to in v {
227                vs.insert(g.get_node(*to).unwrap().clone());
228            }
229            res.insert(k, vs);
230        }
231        res
232    }
233}
234
#[macro_export]
/// Builds an undirected [`UdGraph`] from an adjacency-list literal.
///
/// Each `{key SEP to, to, ...}` group adds an undirected edge between `key` and every listed
/// `to`. `SEP` may be any single token tree; the examples use `->`, but it only aids
/// readability and does not make edges directed. Groups may optionally be separated by
/// commas. A group with no targets adds nothing. A self edge such as `{1 -> 1}` adds the node
/// but no edge.
///
/// Every node token is converted to text with `stringify!` and parsed with
/// `GraphNodeFromStr::from_str`. The node type is either inferred from context or given
/// explicitly as a leading `Type;`.
///
/// # Errors
///
/// The macro evaluates to a `duskphantom_graph::Result<UdGraph<_>>`, which is `Err` if any
/// node token fails to parse.
///
/// # Examples
///
/// ```rust
/// use duskphantom_graph::*;
/// let g: UdGraph<u32> = udgraph!(
///     {1 -> 2, 3},
///     {2 -> 3}
/// )
/// .unwrap();
/// ```
///
/// or, naming the node type explicitly:
///
/// ```rust
/// use duskphantom_graph::*;
/// let g: UdGraph<u32> = udgraph!(u32; {1 -> 2, 3}, {2 -> 3}).unwrap();
/// ```
macro_rules! udgraph {
    ($({$key:tt $sep:tt $($tos:tt),*}$(,)?)*) => {{
        // `$crate::Result` (the re-exported `anyhow::Result`) keeps the expansion working in
        // crates that do not depend on `anyhow` themselves.
        let parse_graph = || -> $crate::Result<$crate::UdGraph<_>> {
            let mut g = $crate::UdGraph::new();
            $(
                $(
                    let k = $crate::GraphNodeFromStr::from_str(stringify!($key))?;
                    let v = $crate::GraphNodeFromStr::from_str(stringify!($tos))?;
                    g.add_edge(k, v);
                )*
            )*
            Ok(g)
        };
        parse_graph()
    }};
    ($n_ty:ty; $({$key:tt $sep:tt $($tos:tt),*}$(,)?)*) => {{
        let parse_graph = || -> $crate::Result<$crate::UdGraph<$n_ty>> {
            let mut g = $crate::UdGraph::new();
            $(
                $(
                    let k: $n_ty = $crate::GraphNodeFromStr::from_str(stringify!($key))?;
                    let v: $n_ty = $crate::GraphNodeFromStr::from_str(stringify!($tos))?;
                    g.add_edge(k, v);
                )*
            )*
            Ok(g)
        };
        parse_graph()
    }};
}
280
#[cfg(test)]
mod tests {

    use super::*;

    /// Builds the triangle 1 -- 2 -- 3 -- 1, inserting nodes in the order 1, 2, 3.
    fn triangle() -> UdGraph<u32> {
        let mut g = UdGraph::<u32>::new();
        g.add_edge(1, 2);
        g.add_edge(1, 3);
        g.add_edge(2, 3);
        g
    }

    /// Sorted neighbors of `node`; panics if `node` is not in `g`.
    fn sorted_nbs(g: &UdGraph<u32>, node: u32) -> Vec<u32> {
        let mut nbs: Vec<u32> = g.get_nbs(&node).unwrap().copied().collect();
        nbs.sort_unstable();
        nbs
    }

    #[test]
    fn basic() {
        let g = triangle();

        let mut ns = g.nodes().collect::<Vec<&u32>>();
        ns.sort();
        assert_eq!(ns, vec![&1, &2, &3]);

        assert_eq!(sorted_nbs(&g, 1), vec![2, 3]);
        assert_eq!(sorted_nbs(&g, 2), vec![1, 3]);
    }

    #[test]
    fn unknown_node_has_no_neighbors() {
        let g = triangle();
        assert!(g.get_nbs(&42).is_none());
    }

    #[test]
    fn neighbors_contains() {
        let g = triangle();
        let mut nbs = g.get_nbs(&1).unwrap();
        assert!(nbs.contains(&2));
        assert!(nbs.contains(&3));
        assert!(!nbs.contains(&1));
        assert!(!nbs.contains(&42));
        // `contains` checks the whole neighbor set, even once the iterator is drained.
        assert_eq!(nbs.by_ref().count(), 2);
        assert!(nbs.contains(&2));
    }

    #[test]
    fn iter_yields_every_node_with_its_neighbors() {
        let g = triangle();
        let mut adjacency: Vec<(u32, Vec<u32>)> = g
            .iter()
            .map(|(node, nbs)| {
                let mut nbs: Vec<u32> = nbs.copied().collect();
                nbs.sort_unstable();
                (*node, nbs)
            })
            .collect();
        adjacency.sort_unstable();
        assert_eq!(
            adjacency,
            vec![(1, vec![2, 3]), (2, vec![1, 3]), (3, vec![1, 2])]
        );
    }

    #[test]
    fn gen_dot_emits_each_edge_once() {
        let g = triangle();
        assert_eq!(
            g.gen_dot("g", |n| n.to_string()),
            "graph g {\n1 -- 2;\n1 -- 3;\n2 -- 3;\n}\n"
        );
    }

    #[test]
    fn from_edge_set() {
        let edges: HashSet<(u32, u32)> = [(1, 2), (2, 3)].iter().copied().collect();
        let g = UdGraph::from(edges);
        assert_eq!(sorted_nbs(&g, 1), vec![2]);
        assert_eq!(sorted_nbs(&g, 2), vec![1, 3]);
        assert_eq!(sorted_nbs(&g, 3), vec![2]);
    }

    #[test]
    fn test_macro() {
        let g: UdGraph<u32> = udgraph!(
            {1 -> 2,3},
            {2 -> 3}
        )
        .unwrap();
        assert!(!g.is_empty());
        let hm: HashMap<u32, HashSet<u32>> = g.into();
        assert_eq!(hm.len(), 3);
        assert_eq!(hm.get(&1).unwrap().len(), 2);
        assert_eq!(hm.get(&2).unwrap().len(), 2);

        let g = udgraph!(u32; {1 -> 2,3}, {2 -> 3}).unwrap();
        assert!(!g.is_empty());
        let hm: HashMap<u32, HashSet<u32>> = g.into();
        assert_eq!(hm.len(), 3);
        assert_eq!(hm.get(&1).unwrap().len(), 2);
        assert_eq!(hm.get(&2).unwrap().len(), 2);
    }

    #[test]
    // the self to self edge adding will be ignored
    fn test_self_to_self() {
        let g: UdGraph<u32> = udgraph!({1 -> 1}).unwrap();
        assert!(!g.is_empty());
        let hm: HashMap<u32, HashSet<u32>> = g.into();
        assert_eq!(hm.len(), 1);
        assert_eq!(hm.get(&1).unwrap().len(), 0);
    }
}