/// A graph vertex: a numeric identifier together with the payload it carries.
#[derive(Debug)]
pub struct Node<T> {
    /// Identifier of this node. `DirectedGraph::add_node` assigns it
    /// sequentially, so it matches the node's position in the graph.
    pub id: usize,
    /// User payload stored at this node.
    pub data: T,
}
6
7impl<T> Node<T> {
8 pub fn new(id: usize, data: T) -> Self {
9 Self { id, data }
10 }
11}
12
/// Errors returned while building a `DirectedGraph`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum GraphBuildingError {
    /// The given node id does not refer to an existing node.
    NodeNotFound(usize),
    /// A node with the requested id is already present.
    NodeAlreadyExists,
    /// The requested edge is already present.
    EdgeAlreadyExists,
}

impl std::fmt::Display for GraphBuildingError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            GraphBuildingError::NodeNotFound(id) => write!(f, "node {} not found", id),
            GraphBuildingError::NodeAlreadyExists => f.write_str("node already exists"),
            GraphBuildingError::EdgeAlreadyExists => f.write_str("edge already exists"),
        }
    }
}

// Lets the error flow through `?` into `Box<dyn std::error::Error>` and
// error-reporting crates.
impl std::error::Error for GraphBuildingError {}
19
20pub struct DirectedGraph<T> {
24 nodes: Vec<Node<T>>,
25 adjacency_list: Vec<Vec<usize>>,
27 reverse_adjacency_list: Vec<Vec<usize>>,
29}
30
31impl<T> DirectedGraph<T> {
32 pub fn new() -> Self {
33 DirectedGraph {
34 nodes: vec![],
35 adjacency_list: vec![],
36 reverse_adjacency_list: vec![],
37 }
38 }
39
40 pub fn add_node(&mut self, data: T) -> Result<usize, GraphBuildingError> {
44 let id = self.node_count();
45 if self.get_node(id).is_some() {
46 return Err(GraphBuildingError::NodeAlreadyExists);
47 }
48 self.nodes.push(Node::new(id, data));
49 self.adjacency_list.push(vec![]);
50 self.reverse_adjacency_list.push(vec![]);
51 Ok(id)
52 }
53
54 pub fn add_edge(
58 &mut self,
59 source_id: usize,
60 target_id: usize,
61 ) -> Result<(), GraphBuildingError> {
62 if source_id > self.nodes.len() {
64 return Err(GraphBuildingError::NodeNotFound(source_id));
65 }
66 if target_id > self.nodes.len() {
67 return Err(GraphBuildingError::NodeNotFound(target_id));
68 }
69 if self.adjacency_list[source_id].contains(&target_id) {
70 return Err(GraphBuildingError::EdgeAlreadyExists);
71 }
72
73 self.adjacency_list[source_id].push(target_id);
75 self.reverse_adjacency_list[target_id].push(source_id);
76 Ok(())
77 }
78
79 pub fn get_node(&self, id: usize) -> Option<&Node<T>> {
80 self.nodes.get(id)
81 }
82
83 pub fn get_node_id_with<F>(&self, f: F) -> Option<usize>
84 where
85 F: Fn(&T) -> bool,
86 {
87 self.nodes.iter().position(|node| f(&node.data))
88 }
89
90 pub fn get_all_node_ids_with<F>(&self, f: F) -> Vec<usize>
91 where
92 F: Fn(&T) -> bool,
93 {
94 self.nodes
95 .iter()
96 .filter(|node| f(&node.data))
97 .map(|node| node.id)
98 .collect()
99 }
100
101 pub fn get_node_mut(&mut self, id: usize) -> Option<&mut Node<T>> {
102 self.nodes.get_mut(id)
103 }
104
105 pub fn get_children(&self, id: usize) -> Option<&[usize]> {
106 self.adjacency_list
107 .get(id)
108 .map(|indices| indices.as_slice())
109 }
110
111 pub fn get_parents(&self, id: usize) -> Option<&[usize]> {
112 self.reverse_adjacency_list
113 .get(id)
114 .map(|indices| indices.as_slice())
115 }
116
117 pub fn get_bfs(&self, root_id: usize, reverse: bool) -> Vec<usize> {
118 let adjacency = if reverse {
119 &self.reverse_adjacency_list
120 } else {
121 &self.adjacency_list
122 };
123 let node_count = self.node_count();
124 let mut visited = vec![false; node_count];
125 let mut queue = vec![root_id];
126 let mut bfs = Vec::<usize>::new();
127 visited[root_id] = true;
128 while queue.len() > 0 {
129 let node = queue.pop().unwrap();
130 bfs.push(node);
131 for id in 0..node_count {
132 if adjacency[node].contains(&id) && !visited[id] {
133 queue.push(id);
134 visited[id] = true;
135 }
136 }
137 }
138 if reverse {
139 bfs.reverse()
140 }
141 bfs.pop();
142
143 bfs
144 }
145
146 pub fn node_count(&self) -> usize {
147 self.nodes.len()
148 }
149
150 pub fn is_root(&self, id: usize) -> bool {
151 self.reverse_adjacency_list.get(id).unwrap().len() == 0
152 }
153
154 pub fn is_leaf(&self, id: usize) -> bool {
155 self.adjacency_list.get(id).unwrap().len() == 0
156 }
157}
158
159impl<T> DirectedGraph<T> {
160 pub fn map_topology_with_default<U: Default>(&self) -> DirectedGraph<U> {
161 let num_nodes = self.node_count();
162 let mut g = DirectedGraph::<U>::new();
163 for _ in self.nodes.iter() {
164 g.add_node(U::default()).unwrap();
165 }
166 for source_id in 0..num_nodes {
167 if let Some(children_ids) = self.adjacency_list.get(source_id) {
168 for &target_id in children_ids.iter() {
169 g.add_edge(source_id, target_id).unwrap();
170 }
171 }
172 }
173 g
174 }
175
176 pub fn map_topology_with<U, F>(&self, mut f: F) -> DirectedGraph<U>
177 where
178 F: FnMut(&T, usize) -> U,
179 {
180 let num_nodes = self.node_count();
181 let mut g = DirectedGraph::<U>::new();
182 for node in self.nodes.iter() {
183 g.add_node(f(&node.data, node.id)).unwrap();
184 }
185 for source_id in 0..num_nodes {
186 if let Some(children_ids) = self.adjacency_list.get(source_id) {
187 for &target_id in children_ids.iter() {
188 g.add_edge(source_id, target_id).unwrap();
189 }
190 }
191 }
192 g
193 }
194}
195
#[cfg(test)]
mod tests {
    use super::*;

    /// Graph with two unconnected nodes holding 10.0 and 20.0.
    fn two_nodes() -> DirectedGraph<f64> {
        let mut graph = DirectedGraph::new();
        for &value in &[10.0, 20.0] {
            graph.add_node(value).unwrap();
        }
        graph
    }

    /// Same as `two_nodes`, plus the single edge 0 -> 1.
    fn two_nodes_one_edge() -> DirectedGraph<f64> {
        let mut graph = two_nodes();
        assert!(graph.add_edge(0, 1).is_ok());
        graph
    }

    /// Asserts that `g` has exactly the 0 -> 1 two-node topology.
    fn assert_two_node_chain<U>(g: &DirectedGraph<U>) {
        assert_eq!(g.node_count(), 2);
        assert!(g.get_node(0).is_some());
        assert!(g.get_node(1).is_some());
        assert!(g.is_root(0));
        assert!(g.is_leaf(1));
    }

    #[test]
    fn test_create_directed_graph() {
        assert_eq!(DirectedGraph::<f64>::new().node_count(), 0);
    }

    #[test]
    fn test_add_node_to_directed_graph() {
        assert_eq!(two_nodes().node_count(), 2);
    }

    #[test]
    fn test_add_edge_to_directed_graph() {
        let mut graph = two_nodes();
        assert!(graph.add_edge(0, 1).is_ok())
    }

    #[test]
    fn test_map_topology_with_default() {
        let mapped = two_nodes_one_edge().map_topology_with_default::<usize>();
        assert_two_node_chain(&mapped);
    }

    #[test]
    fn test_map_topology_with() {
        let mapped = two_nodes_one_edge().map_topology_with(|value, _id| *value as usize);
        assert_two_node_chain(&mapped);
    }
}