1use std::collections::{HashMap, HashSet};
18
19pub use anyhow::anyhow;
20pub use anyhow::Result;
21
/// Marker trait for values that can be stored as vertices of a [`UdGraph`].
///
/// A vertex must be hashable and comparable for equality (it is used as a hash-map
/// key) and cloneable. It is implemented automatically for every type meeting
/// those bounds, so users never implement it by hand.
pub trait GraphNode
where
    Self: std::hash::Hash + std::cmp::Eq + Sized + Clone,
{
}

impl<T: std::hash::Hash + Eq + Clone> GraphNode for T {}
25
26pub trait GraphNodeFromStr: GraphNode {
27 fn from_str(input: &str) -> Result<Self>
29 where
30 Self: Sized;
31}
32
33impl<T> GraphNodeFromStr for T
34where
35 T: std::str::FromStr + GraphNode,
36{
37 fn from_str(input: &str) -> Result<Self> {
38 input.parse().map_err(|_| anyhow!("parse error"))
39 }
40}
41
42pub struct UdGraph<T: GraphNode> {
44 id_alloc: u64,
45 n_id: std::collections::HashMap<T, u64>,
46 id_n: std::collections::HashMap<u64, T>,
47 edges: std::collections::HashMap<u64, std::collections::HashSet<u64>>,
48}
49
50impl<T: GraphNode> UdGraph<T> {
51 #[allow(clippy::new_without_default)]
52 pub fn new() -> Self {
53 Self {
54 id_alloc: 0,
55 n_id: std::collections::HashMap::new(),
56 id_n: std::collections::HashMap::new(),
57 edges: std::collections::HashMap::new(),
58 }
59 }
60
61 pub fn add_node(&mut self, node: T) -> u64 {
62 if let Some(id) = self.n_id.get(&node) {
63 return *id;
64 }
65 let id = self.id_alloc;
66 assert!(id < u64::MAX);
67 self.id_alloc += 1;
68 self.n_id.insert(node.clone(), id);
69 self.id_n.insert(id, node);
70
71 self.edges.entry(id).or_default();
72
73 id
74 }
75
76 pub fn add_edge(&mut self, from: T, to: T) {
79 let from_id = self.add_node(from);
80 let to_id = self.add_node(to);
81 self._add_edge_by_id(from_id, to_id)
82 }
83
84 pub fn add_edge_ref(&mut self, from: &T, to: &T) {
85 let from_id = self.add_node(from.clone());
86 let to_id = self.add_node(to.clone());
87 self._add_edge_by_id(from_id, to_id)
88 }
89
90 #[inline]
91 fn _add_edge_by_id(&mut self, from_id: u64, to_id: u64) {
92 if from_id == to_id {
93 return;
94 }
95 self.edges.entry(from_id).or_default().insert(to_id);
96 self.edges.entry(to_id).or_default().insert(from_id);
97 }
98
99 fn get_node(&self, id: u64) -> Option<&T> {
100 self.id_n.get(&id)
101 }
102
103 pub fn nodes(&self) -> std::collections::hash_map::Values<u64, T> {
104 self.id_n.values()
105 }
106
107 pub fn iter(&self) -> UdGraphIter<T> {
108 UdGraphIter {
109 graph: self,
110 nodes_iter: self.id_n.keys(),
111 }
112 }
113
114 pub fn is_empty(&self) -> bool {
115 self.id_n.is_empty()
116 }
117
118 pub fn get_nbs<'a>(&'a self, from: &T) -> Option<Neighbors<'a, T>> {
120 let id = self.n_id.get(from)?;
121 let tos = self.edges.get(id)?;
122 let tos_iter = tos.iter();
123 Some(Neighbors {
124 graph: self,
125 tos,
126 tos_iter,
127 })
128 }
129}
130
131impl<T: GraphNode> UdGraph<T> {
132 pub fn gen_dot(&self, graph_name: &str, node_shower: impl Fn(&T) -> String) -> String {
133 let mut res = String::new();
134 res.push_str(&format!("graph {} {{\n", graph_name));
135
136 let mut showed: HashSet<(u64, u64)> = HashSet::new();
137 let mut sorted_edges: Vec<(&u64, &HashSet<u64>)> = self.edges.iter().collect();
138 sorted_edges.sort_by_key(|(k, _)| **k);
139 for (k, nbs) in sorted_edges {
140 let from = self.get_node(*k).unwrap();
141 let from_str = node_shower(from);
142 let mut sorted_nbs: Vec<&u64> = nbs.iter().collect();
143 sorted_nbs.sort_by_key(|&&x| x);
144 for to in sorted_nbs {
145 if showed.contains(&(*k, *to)) || showed.contains(&(*to, *k)) {
146 continue;
147 }
148 showed.insert((*k, *to));
149 let to = self.get_node(*to).unwrap();
150 let to_str = node_shower(to);
151 res.push_str(&format!("{} -- {};\n", from_str, to_str));
152 }
153 }
154 res.push_str("}\n");
155 res
156 }
157}
158
159pub struct UdGraphIter<'a, T: GraphNode> {
160 graph: &'a UdGraph<T>,
161 nodes_iter: std::collections::hash_map::Keys<'a, u64, T>,
162}
163
164impl<'a, T: GraphNode> Iterator for UdGraphIter<'a, T> {
165 type Item = (&'a T, Neighbors<'a, T>);
166 fn next(&mut self) -> Option<Self::Item> {
167 self.nodes_iter.next().map(|id| {
168 let node = self.graph.get_node(*id).unwrap();
169 let nbs = self.graph.get_nbs(node).unwrap();
170 (node, nbs)
171 })
172 }
173}
174
175pub struct Neighbors<'a, T: GraphNode> {
176 graph: &'a UdGraph<T>,
177 tos: &'a std::collections::HashSet<u64>,
178 tos_iter: std::collections::hash_set::Iter<'a, u64>,
179}
180impl<'a, T: GraphNode> Iterator for Neighbors<'a, T> {
182 type Item = &'a T;
183 fn next(&mut self) -> Option<Self::Item> {
184 self.tos_iter
185 .next()
186 .map(|id| self.graph.get_node(*id).unwrap())
187 }
188}
189
190impl<T: GraphNode> Neighbors<'_, T> {
191 pub fn contains(&self, node: &T) -> bool {
192 self.graph
193 .n_id
194 .get(node)
195 .map_or(false, |id| self.tos.contains(id))
196 }
197}
198
199impl<T: GraphNode> From<HashMap<T, HashSet<T>>> for UdGraph<T> {
200 fn from(value: HashMap<T, HashSet<T>>) -> Self {
201 let mut g = UdGraph::new();
202 for (k, v) in value {
203 for to in v {
204 g.add_edge(k.clone(), to);
205 }
206 }
207 g
208 }
209}
210impl<T: GraphNode> From<HashSet<(T, T)>> for UdGraph<T> {
211 fn from(value: HashSet<(T, T)>) -> Self {
212 let mut g = UdGraph::new();
213 for (k, v) in value {
214 g.add_edge(k, v);
215 }
216 g
217 }
218}
219
220impl<T: GraphNode> From<UdGraph<T>> for HashMap<T, HashSet<T>> {
221 fn from(g: UdGraph<T>) -> Self {
222 let mut res = HashMap::new();
223 for (k, v) in g.edges.iter() {
224 let k = g.get_node(*k).unwrap().clone();
225 let mut vs = HashSet::new();
226 for to in v {
227 vs.insert(g.get_node(*to).unwrap().clone());
228 }
229 res.insert(k, vs);
230 }
231 res
232 }
233}
234
/// Builds a [`UdGraph`] from a literal adjacency list.
///
/// Each `{key sep a, b, ...}` group adds `key` and an undirected edge from `key` to
/// every listed neighbour. `sep` may be any single token (for example `->`) and is
/// ignored.
///
/// Every token is converted to a node by passing its `stringify!`-ed text to
/// [`GraphNodeFromStr::from_str`]. A group with no neighbours (`{key ->}`) still adds
/// `key` as an isolated node.
///
/// The first form infers the node type from context. The second form names it
/// explicitly: `udgraph!(u32; {1 -> 2})`. Both forms evaluate to
/// `Result<UdGraph<_>>`, which fails on the first token that does not parse.
#[macro_export]
macro_rules! udgraph {
    ($({$key:tt $sep:tt $($tos:tt),*}$(,)?)*) => {{
        // `$crate::Result` is the crate's re-export of `anyhow::Result`. Using it keeps
        // the expansion working in crates that do not depend on `anyhow` themselves.
        let parse_graph = || -> $crate::Result<$crate::UdGraph<_>> {
            let mut g = $crate::UdGraph::new();
            $({
                let k = $crate::GraphNodeFromStr::from_str(stringify!($key))?;
                // Add the key first so that a group without neighbours still yields a
                // node. `Clone::clone(&k)` instead of `k.clone()`: k's type is still an
                // inference variable here, so a method call on it would not resolve.
                g.add_node(Clone::clone(&k));
                $(
                    let v = $crate::GraphNodeFromStr::from_str(stringify!($tos))?;
                    g.add_edge(Clone::clone(&k), v);
                )*
            })*
            Ok(g)
        };
        parse_graph()
    }};
    ($n_ty:ty; $({$key:tt $sep:tt $($tos:tt),*}$(,)?)*) => {{
        let parse_graph = || -> $crate::Result<$crate::UdGraph<$n_ty>> {
            let mut g = $crate::UdGraph::new();
            $({
                let k: $n_ty = $crate::GraphNodeFromStr::from_str(stringify!($key))?;
                // Add the key first so that a group without neighbours still yields a node.
                g.add_node(Clone::clone(&k));
                $(
                    let v: $n_ty = $crate::GraphNodeFromStr::from_str(stringify!($tos))?;
                    g.add_edge(Clone::clone(&k), v);
                )*
            })*
            Ok(g)
        };
        parse_graph()
    }};
}
280
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn basic() {
        let mut g = UdGraph::<u32>::new();
        g.add_edge(1, 2);
        g.add_edge(1, 3);
        g.add_edge(2, 3);

        let mut ns = g.nodes().collect::<Vec<&u32>>();
        ns.sort();
        assert_eq!(ns, vec![&1, &2, &3]);

        let mut nbs: Vec<&u32> = g.get_nbs(&1).unwrap().collect();
        nbs.sort();
        assert_eq!(nbs, vec![&2, &3]);

        let mut nbs: Vec<&u32> = g.get_nbs(&2).unwrap().collect();
        nbs.sort();
        assert_eq!(nbs, vec![&1, &3]);
    }

    #[test]
    fn test_macro() {
        let g: UdGraph<u32> = udgraph!(
            {1 -> 2,3},
            {2 -> 3}
        )
        .unwrap();
        assert!(!g.is_empty());
        let hm: HashMap<u32, HashSet<u32>> = g.into();
        assert_eq!(hm.len(), 3);
        assert_eq!(hm.get(&1).unwrap().len(), 2);
        assert_eq!(hm.get(&2).unwrap().len(), 2);

        let g = udgraph!(u32; {1 -> 2,3}, {2 -> 3}).unwrap();
        assert!(!g.is_empty());
        let hm: HashMap<u32, HashSet<u32>> = g.into();
        assert_eq!(hm.len(), 3);
        assert_eq!(hm.get(&1).unwrap().len(), 2);
        assert_eq!(hm.get(&2).unwrap().len(), 2);
    }

    #[test]
    fn test_self_to_self() {
        // A self-loop inserts the node but stores no edge.
        let g: UdGraph<u32> = udgraph!({1 -> 1}).unwrap();
        assert!(!g.is_empty());
        let hm: HashMap<u32, HashSet<u32>> = g.into();
        assert_eq!(hm.len(), 1);
        assert_eq!(hm.get(&1).unwrap().len(), 0);
    }

    #[test]
    fn neighbors_contains() {
        let mut g = UdGraph::<u32>::new();
        g.add_edge(1, 2);
        g.add_edge(1, 3);
        let nbs = g.get_nbs(&1).unwrap();
        assert!(nbs.contains(&2));
        assert!(nbs.contains(&3));
        assert!(!nbs.contains(&1));
        assert!(!nbs.contains(&42));
        assert!(g.get_nbs(&42).is_none());
    }

    #[test]
    fn iter_yields_every_node_with_its_neighbors() {
        let mut g = UdGraph::<u32>::new();
        g.add_edge(1, 2);
        g.add_edge(2, 3);
        let mut degrees: Vec<(u32, usize)> =
            g.iter().map(|(n, nbs)| (*n, nbs.count())).collect();
        degrees.sort();
        assert_eq!(degrees, vec![(1, 1), (2, 2), (3, 1)]);
    }

    #[test]
    fn from_edge_set() {
        let edges: HashSet<(u32, u32)> = vec![(1, 2), (2, 3)].into_iter().collect();
        let g = UdGraph::from(edges);
        let mut nbs: Vec<u32> = g.get_nbs(&2).unwrap().copied().collect();
        nbs.sort();
        assert_eq!(nbs, vec![1, 3]);
    }

    #[test]
    fn gen_dot_emits_each_edge_once() {
        let mut g = UdGraph::<u32>::new();
        g.add_edge(1, 2);
        g.add_edge(1, 3);
        g.add_edge(2, 3);
        assert_eq!(
            g.gen_dot("g", |n| n.to_string()),
            "graph g {\n1 -- 2;\n1 -- 3;\n2 -- 3;\n}\n"
        );
    }
}