pulsar_utils/
digraph.rs

1// Copyright (C) 2024 Ethan Uppal. All rights reserved.
2use std::{
3    collections::{HashMap, HashSet},
4    hash::Hash
5};
6
/// A directed graph with labelled edges, stored as an adjacency map.
///
/// Trait bounds live on the `impl` blocks rather than on the struct itself,
/// so the type can be named (e.g. in signatures) without repeating them.
#[derive(Debug, Clone)]
pub struct Digraph<Node, Edge> {
    /// Maps each node to its outgoing edges, as `(edge label, target)` pairs
    /// in insertion order.
    adj: HashMap<Node, Vec<(Edge, Node)>>
}
10
11impl<Node: Hash + Eq, Edge> Digraph<Node, Edge> {
12    pub fn new() -> Self {
13        Self {
14            adj: HashMap::new()
15        }
16    }
17
18    pub fn add_node(&mut self, node: Node) {
19        self.adj.insert(node, vec![]);
20    }
21
22    pub fn add_edge(&mut self, u: Node, e: Edge, v: Node) {
23        if let Some(out) = self.adj.get_mut(&u) {
24            out.push((e, v));
25        }
26    }
27
28    pub fn out_of(&self, node: Node) -> Option<&Vec<(Edge, Node)>> {
29        self.adj.get(&node)
30    }
31
32    pub fn nodes(&self) -> Vec<&Node> {
33        self.adj.keys().collect()
34    }
35
36    pub fn node_count(&self) -> usize {
37        self.adj.len()
38    }
39}
40
41impl<Node: Hash + Eq + Clone, Edge> Digraph<Node, Edge> {
42    /// Conducts a depth-first search (DFS) starting from `start`, calling `f`
43    /// on each node encountered in DFS order. This function performs multiple
44    /// `clone()`s of the nodes, which is still performant when nodes are
45    /// smart pointers such as `Rc`.
46    ///
47    /// Requires: `start` is in the graph.
48    pub fn dfs<F>(&self, mut f: F, start: Node)
49    where
50        F: FnMut(Node) {
51        assert!(self.adj.contains_key(&start));
52
53        let mut visited = HashSet::new();
54        let mut stack = vec![];
55
56        stack.push(start);
57        while let Some(node) = stack.pop() {
58            visited.insert(node.clone());
59            f(node.clone());
60            for (_, next) in self.out_of(node).unwrap() {
61                stack.push(next.clone());
62            }
63        }
64    }
65}