1use std::cmp;
8
9#[derive(Debug)]
11pub struct DAG {
12 nodes: Vec<Node>,
14
15 ranks: RankType,
17
18 validate: bool,
20}
21
/// Lightweight, copyable reference to a node of a [`DAG`], wrapping the
/// node's index. Handles order and hash by that index.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct NodeHandle {
    /// Position of the node in the owning graph's node storage.
    idx: usize,
}
27
28impl NodeHandle {
29 pub fn new(x: usize) -> Self {
30 NodeHandle { idx: x }
31 }
32 pub fn get_index(&self) -> usize {
33 self.idx
34 }
35}
36
37impl From<usize> for NodeHandle {
38 fn from(idx: usize) -> Self {
39 NodeHandle { idx }
40 }
41}
42
43#[derive(Debug)]
44struct Node {
45 successors: Vec<NodeHandle>,
47 predecessors: Vec<NodeHandle>,
48}
49
50pub type RankType = Vec<Vec<NodeHandle>>;
51
52impl Node {
53 pub fn new() -> Self {
54 Node {
55 successors: Vec::new(),
56 predecessors: Vec::new(),
57 }
58 }
59}
60
/// Iterator over the node handles whose indices lie in the half-open range
/// `curr..last`.
#[derive(Debug)]
pub struct NodeIterator {
    /// Index of the next handle to yield.
    curr: usize,
    /// One past the last index to yield.
    last: usize,
}
67
68impl Iterator for NodeIterator {
69 type Item = NodeHandle;
70
71 fn next(&mut self) -> Option<Self::Item> {
72 if self.curr == self.last {
73 return None;
74 }
75
76 let item = Some(NodeHandle::from(self.curr));
77 self.curr += 1;
78 item
79 }
80}
81
82impl DAG {
83 pub fn new() -> Self {
84 DAG {
85 nodes: Vec::new(),
86 ranks: Vec::new(),
87 validate: true,
88 }
89 }
90
91 pub fn set_validate(&mut self, validate: bool) {
92 self.validate = validate;
93 }
94
95 pub fn clear(&mut self) {
96 self.nodes.clear();
97 self.ranks.clear();
98 }
99
100 pub fn iter(&self) -> NodeIterator {
101 NodeIterator {
102 curr: 0,
103 last: self.nodes.len(),
104 }
105 }
106
107 pub fn add_edge(&mut self, from: NodeHandle, to: NodeHandle) {
108 self.nodes[from.idx].successors.push(to);
109 self.nodes[to.idx].predecessors.push(from);
110 }
111
112 pub fn remove_edge(&mut self, from: NodeHandle, to: NodeHandle) -> bool {
115 let succ = &mut self.nodes[from.idx].successors;
116 let mut removed_succ = false;
117
118 if let Some(pos) = succ.iter().position(|x| *x == to) {
119 succ.remove(pos);
120 removed_succ = true;
121 }
122
123 let pred = &mut self.nodes[to.idx].predecessors;
124 let mut removed_pred = false;
125 if let Some(pos) = pred.iter().position(|x| *x == from) {
126 pred.remove(pos);
127 removed_pred = true;
128 }
129
130 assert_eq!(removed_pred, removed_succ);
133 removed_pred
134 }
135
136 pub fn new_node(&mut self) -> NodeHandle {
138 self.nodes.push(Node::new());
139 let node = NodeHandle::new(self.nodes.len() - 1);
140 self.add_element_to_rank(node, 0, false);
141 node
142 }
143
144 pub fn new_nodes(&mut self, n: usize) {
146 for _ in 0..n {
147 self.nodes.push(Node::new());
148 let node = NodeHandle::new(self.nodes.len() - 1);
149 self.add_element_to_rank(node, 0, false);
150 }
151 self.verify();
152 }
153
154 pub fn successors(&self, from: NodeHandle) -> &Vec<NodeHandle> {
155 &self.nodes[from.idx].successors
156 }
157
158 pub fn predecessors(&self, from: NodeHandle) -> &Vec<NodeHandle> {
159 &self.nodes[from.idx].predecessors
160 }
161
162 pub fn single_pred(&self, from: NodeHandle) -> Option<NodeHandle> {
163 if self.nodes[from.idx].predecessors.len() == 1 {
164 return Some(self.nodes[from.idx].predecessors[0]);
165 }
166 None
167 }
168
169 pub fn single_succ(&self, from: NodeHandle) -> Option<NodeHandle> {
170 if self.nodes[from.idx].successors.len() == 1 {
171 return Some(self.nodes[from.idx].successors[0]);
172 }
173 None
174 }
175
176 pub fn verify(&self) {
177 if self.validate {
178 for node in &self.nodes {
180 for edge in &node.successors {
181 assert!(edge.idx < self.nodes.len());
182 }
183 }
184
185 for (i, node) in self.nodes.iter().enumerate() {
187 let from = NodeHandle::from(i);
188 for dest in node.successors.iter() {
189 let reachable =
190 self.is_reachable(*dest, from) && from != *dest;
191 assert!(!reachable, "We found a cycle!");
192 }
193 }
194
195 assert_eq!(self.count_nodes_in_ranks(), self.len());
197 }
198 }
199
200 pub fn len(&self) -> usize {
201 self.nodes.len()
202 }
203
204 pub fn is_empty(&self) -> bool {
205 self.nodes.is_empty()
206 }
207
208 fn is_reachable_inner(
211 &self,
212 from: NodeHandle,
213 to: NodeHandle,
214 visited: &mut Vec<bool>,
215 ) -> bool {
216 if from == to {
217 return true;
218 }
219
220 if visited[from.idx] {
222 return false;
223 }
224
225 visited[from.idx] = true;
227
228 let from_node = &self.nodes[from.idx];
229 for edge in &from_node.successors {
230 if self.is_reachable_inner(*edge, to, visited) {
231 return true;
232 }
233 }
234
235 visited[from.idx] = false;
237 false
238 }
239
240 pub fn is_reachable(&self, from: NodeHandle, to: NodeHandle) -> bool {
242 if from == to {
243 return true;
244 }
245
246 let mut visited = Vec::new();
247 visited.resize(self.nodes.len(), false);
248 self.is_reachable_inner(from, to, &mut visited)
249 }
250
251 fn topological_sort(&self) -> Vec<NodeHandle> {
254 let mut order: Vec<NodeHandle> = Vec::new();
256
257 let mut visited = Vec::new();
259 visited.resize(self.nodes.len(), false);
260
261 let mut worklist: Vec<(NodeHandle, bool)> = Vec::new();
265
266 for n in self.iter() {
268 worklist.push((n, false));
269 }
270
271 while let Some((current, cmd)) = worklist.pop() {
272 if cmd {
274 order.push(current);
275 continue;
276 }
277
278 if visited[current.idx] {
280 continue;
281 }
282
283 visited[current.idx] = true;
284
285 worklist.push((current, true));
287
288 let node = &self.nodes[current.idx];
290 for edge in &node.successors {
291 worklist.push((*edge, false));
292 }
293 }
294
295 order.reverse();
297 order
298 }
299
300 pub fn num_levels(&self) -> usize {
304 self.ranks.len()
305 }
306
307 pub fn row_mut(&mut self, level: usize) -> &mut Vec<NodeHandle> {
309 assert!(level < self.ranks.len(), "Invalid rank");
310 &mut self.ranks[level]
311 }
312
313 pub fn row(&self, level: usize) -> &Vec<NodeHandle> {
315 assert!(level < self.ranks.len(), "Invalid rank");
316 &self.ranks[level]
317 }
318
319 pub fn ranks(&self) -> &RankType {
321 &self.ranks
322 }
323
324 pub fn ranks_mut(&mut self) -> &mut RankType {
326 &mut self.ranks
327 }
328
329 pub fn is_first_in_row(&self, elem: NodeHandle, level: usize) -> bool {
331 if level >= self.ranks.len() || self.ranks[level].is_empty() {
332 return false;
333 }
334 self.ranks[level][0] == elem
335 }
336
337 pub fn is_last_in_row(&self, elem: NodeHandle, level: usize) -> bool {
339 if level >= self.ranks.len() || self.ranks[level].is_empty() {
340 return false;
341 }
342 let last_idx = self.ranks[level].len() - 1;
343 self.ranks[level][last_idx] == elem
344 }
345
346 fn add_element_to_rank(
351 &mut self,
352 elem: NodeHandle,
353 level: usize,
354 prepend: bool,
355 ) {
356 while self.ranks.len() < level + 1 {
357 self.ranks.push(Vec::new());
358 }
359
360 if prepend {
361 self.ranks[level].insert(0, elem);
362 } else {
363 self.ranks[level].push(elem);
364 }
365 }
366
367 pub fn recompute_node_ranks(&mut self) {
369 assert!(!self.is_empty(), "Sorting an empty graph");
370 let order = self.topological_sort();
371 let levels = self.compute_levels(&order);
372 self.ranks.clear();
373 for (i, level) in levels.iter().enumerate() {
374 self.add_element_to_rank(NodeHandle::from(i), *level, false);
375 }
376 }
377
378 fn count_nodes_in_ranks(&self) -> usize {
381 let mut cnt = 0;
382 for row in self.ranks.iter() {
383 cnt += row.len();
384 }
385 cnt
386 }
387
388 pub fn update_node_rank_level(
391 &mut self,
392 node: NodeHandle,
393 new_level: usize,
394 insert_before: Option<NodeHandle>,
395 ) {
396 let curr_level = self.level(node);
397 let level = &mut self.ranks[curr_level];
398 let idx = level
399 .iter()
400 .position(|x| *x == node)
401 .expect("node not found");
402 level.remove(idx);
403
404 while self.ranks.len() < new_level + 1 {
406 self.ranks.push(Vec::new());
407 }
408
409 if let Option::Some(marker) = insert_before {
410 let row = &mut self.ranks[new_level];
411 for i in 0..row.len() {
412 if row[i] == marker {
413 row.insert(i, node);
414 return;
415 }
416 }
417 panic!("Can't find the marker node in the array");
418 }
419
420 self.ranks[new_level].push(node);
421 assert_eq!(self.level(node), new_level);
422 }
423
424 pub fn level(&self, node: NodeHandle) -> usize {
426 assert!(node.get_index() < self.len(), "Node not in the dag");
427 for (i, row) in self.ranks.iter().enumerate() {
428 if row.contains(&node) {
429 return i;
430 }
431 }
432 panic!("Unexpected node. Is the graph ranked?");
433 }
434
435 fn compute_levels(&self, order: &[NodeHandle]) -> Vec<usize> {
438 let mut levels: Vec<usize> = Vec::new();
439 assert_eq!(order.len(), self.nodes.len());
440
441 levels.resize(self.nodes.len(), 0);
443
444 for src in order {
446 for dest in self.nodes[src.idx].successors.iter() {
448 if src.idx == dest.idx {
450 continue;
451 }
452 levels[dest.idx] =
453 cmp::max(levels[dest.idx], levels[src.idx] + 1);
454 }
455 }
456
457 for src in order {
459 for dest in self.nodes[src.idx].successors.iter() {
460 assert!(levels[dest.idx] >= levels[src.idx]);
461 }
462 }
463
464 levels
465 }
466}
467
468impl Default for DAG {
469 fn default() -> Self {
470 Self::new()
471 }
472}
473
/// Builds a five-node chain (plus a shortcut edge) and checks the topological
/// order and the longest-path levels computed for it.
#[test]
fn test_simple_construction() {
    let mut g = DAG::new();
    let h0 = g.new_node();
    g.verify();

    let h1 = g.new_node();
    let h2 = g.new_node();
    let h3 = g.new_node();
    let h4 = g.new_node();

    assert_ne!(h0, h1);
    assert_ne!(h1, h2);

    g.add_edge(h0, h1);
    g.add_edge(h1, h2);
    g.add_edge(h0, h2);
    g.add_edge(h2, h3);
    g.add_edge(h3, h4);

    g.verify();

    let order = g.topological_sort();
    let levels = g.compute_levels(&order);
    assert_eq!(order.len(), g.len());
    assert_eq!(levels.len(), g.len());

    // The chain h0 -> h1 -> h2 -> h3 -> h4 visits every node, so exactly one
    // topological order exists.
    assert_eq!(order, vec![h0, h1, h2, h3, h4]);
    // `levels` is indexed by node index. The shortcut h0 -> h2 must not pull
    // h2 above the longer path through h1.
    assert_eq!(levels, vec![0, 1, 2, 3, 4]);

    for (i, node) in order.iter().enumerate() {
        // Look the level up by node index, not by position in `order`.
        println!("{}) node {}, level {}", i, node.idx, levels[node.idx]);
    }
}
505
/// Ranks a three-node chain, checks each node's level, then checks that
/// removing an edge succeeds once and fails the second time.
#[test]
fn test_rank_api() {
    let mut g = DAG::new();
    let handles: Vec<NodeHandle> = (0..3).map(|_| g.new_node()).collect();
    let (h0, h1, h2) = (handles[0], handles[1], handles[2]);

    g.add_edge(h0, h1);
    g.add_edge(h1, h2);

    g.recompute_node_ranks();
    g.verify();

    // In a simple chain each node sits one level below its predecessor.
    for (expected, &h) in handles.iter().enumerate() {
        assert_eq!(g.level(h), expected);
    }

    let first_removal = g.remove_edge(h0, h1);
    let second_removal = g.remove_edge(h0, h1);
    assert!(first_removal);
    assert!(!second_removal);
}