1use fnv::{FnvHashMap, FnvHashSet};
2use generational_arena::{Arena, Index};
3use std::borrow::Borrow;
4use std::collections::{hash_map, VecDeque};
5use std::hash::Hash;
6use std::ops::RangeBounds;
7
8#[cfg(any(test, feature = "test-utils"))]
9pub mod naive;
10
11pub struct Dag<N> {
15 entries: Arena<Entry>,
17 nodes: FnvHashMap<N, Index>,
19 cycles: FnvHashSet<Cycle>,
21 next_order: u64,
23}
24
25#[derive(Default)]
27struct Entry {
28 forward: FnvHashSet<Index>,
29 backward: FnvHashSet<Index>,
30 order: u64,
31}
32
/// Errors returned by `Dag::connect` and `Dag::disconnect`.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum Error {
    /// At least one of the two endpoints is not a node of the graph.
    NotFound,
    /// Both endpoints are the same node; self-loops are rejected.
    SelfLoop,
}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let message = match self {
            Error::NotFound => "node not found in graph",
            Error::SelfLoop => "self-loops are not allowed",
        };
        f.write_str(message)
    }
}

// Lets callers propagate `Error` with `?` into `Box<dyn std::error::Error>`
// and other generic error types.
impl std::error::Error for Error {}
40
41struct DagTraverser {
44 stack: VecDeque<Index>,
46 visited: FnvHashSet<Index>,
48 direction: Direction,
50}
51
/// Which adjacency set a `DagTraverser` follows from each node.
enum Direction {
    /// Follow outgoing edges (`Entry::forward`), visiting successors.
    Forward,
    /// Follow incoming edges (`Entry::backward`), visiting predecessors.
    Backward,
}
59
/// Returned by a `Visitor` to tell the traverser whether to keep going.
#[derive(Clone, Copy, PartialEq, Eq)]
enum ControlFlow {
    /// Abort the traversal right away.
    Stop,
    /// Keep going: queue this node's neighbours and process the rest.
    Continue,
}
69
70struct SearchVisitor {
72 target: Index,
73 found: bool,
74}
75
76#[derive(Default)]
78struct CollectVisitor {
79 collected: Vec<(Index, u64)>,
80}
81
82trait Visitor {
83 fn visit(&mut self, index: &Index, order: u64) -> ControlFlow;
84}
85
86#[derive(Hash, Eq, PartialEq, Debug)]
87struct Cycle(Index, Index);
88
89enum GraphChange {
90 DeleteNode(Index),
91 DeleteEdge(Index, Index),
92}
93
94impl<N> Dag<N>
95where
96 N: Hash + Eq,
97{
98 pub fn new() -> Self {
100 Self {
101 entries: Arena::new(),
102 nodes: FnvHashMap::default(),
103 cycles: FnvHashSet::default(),
104 next_order: 0,
105 }
106 }
107
108 pub fn insert(&mut self, node: N) -> bool {
110 if let hash_map::Entry::Vacant(e) = self.nodes.entry(node) {
111 let entry = Entry {
112 order: self.next_order,
113 ..Entry::default()
114 };
115 let index = self.entries.insert(entry);
116 e.insert(index);
117 self.next_order += 1;
118 true
119 } else {
120 false
121 }
122 }
123
124 pub fn remove<Q: ?Sized>(&mut self, node: &Q) -> Option<N>
127 where
128 N: Borrow<Q>,
129 Q: Hash + Eq,
130 {
131 let (node, index) = if let Some((n, index)) = self.nodes.remove_entry(node) {
132 (n, index)
133 } else {
134 return None;
135 };
136
137 let entry = self.entries.remove(index).unwrap();
138
139 for i in entry.forward {
140 let entry = self.entries.get_mut(i).unwrap();
141 entry.backward.remove(&index);
142 }
143
144 for i in entry.backward {
145 let entry = self.entries.get_mut(i).unwrap();
146 entry.forward.remove(&index);
147 }
148
149 if !self.cycles.is_empty() {
150 self.update_cycles(GraphChange::DeleteNode(index));
151 }
152
153 Some(node)
154 }
155
156 pub fn connect<Q: ?Sized>(&mut self, v: &Q, u: &Q) -> Result<bool, Error>
173 where
174 N: Borrow<Q>,
175 Q: Hash + Eq,
176 {
177 let v_index = *self.nodes.get(v).ok_or(Error::NotFound)?;
178 let u_index = *self.nodes.get(u).ok_or(Error::NotFound)?;
179
180 if v_index == u_index {
182 return Err(Error::SelfLoop);
183 }
184
185 if self
186 .entries
187 .get(v_index)
188 .unwrap()
189 .forward
190 .contains(&u_index)
191 {
192 return Ok(false);
194 }
195
196 self.add_edge_helper(v_index, u_index, false);
197
198 let tmp = self.entries.get2_mut(v_index, u_index);
200 let v_entry = tmp.0.unwrap();
201 let u_entry = tmp.1.unwrap();
202 v_entry.forward.insert(u_index);
203 u_entry.backward.insert(v_index);
204
205 Ok(true)
206 }
207
208 pub fn disconnect<Q: ?Sized>(&mut self, v: &Q, u: &Q) -> Result<bool, Error>
210 where
211 N: Borrow<Q>,
212 Q: Hash + Eq,
213 {
214 let v_index = self.nodes.get(v).ok_or(Error::NotFound)?;
215 let u_index = self.nodes.get(u).ok_or(Error::NotFound)?;
216
217 if v_index == u_index {
218 return Err(Error::SelfLoop);
219 }
220
221 {
222 let tmp = self.entries.get2_mut(*v_index, *u_index);
223 let v_entry = tmp.0.unwrap();
224 let u_entry = tmp.1.unwrap();
225
226 if !v_entry.forward.remove(u_index) {
227 return Ok(false);
229 }
230
231 u_entry.backward.remove(v_index);
232 }
233
234 if !self.cycles.is_empty() {
235 self.update_cycles(GraphChange::DeleteEdge(*v_index, *u_index));
236 }
237
238 Ok(true)
239 }
240
241 pub fn contains<Q: ?Sized>(&self, v: &Q) -> bool
243 where
244 N: Borrow<Q>,
245 Q: Hash + Eq,
246 {
247 self.nodes.contains_key(v)
248 }
249
250 pub fn is_connected<Q: ?Sized>(&self, v: &Q, u: &Q) -> bool
252 where
253 N: Borrow<Q>,
254 Q: Hash + Eq,
255 {
256 let (v_index, u_index) = match (self.nodes.get(v), self.nodes.get(u)) {
257 (Some(v_index), Some(u_index)) => (v_index, u_index),
258 _ => return false,
259 };
260 self.entries
262 .get(*u_index)
263 .unwrap()
264 .backward
265 .contains(v_index)
266 }
267
268 pub fn is_reachable<Q: ?Sized>(&self, v: &Q, u: &Q) -> bool
270 where
271 N: Borrow<Q>,
272 Q: Hash + Eq,
273 {
274 let (v_index, u_index) = match (self.nodes.get(v), self.nodes.get(u)) {
275 (Some(v_index), Some(u_index)) => (v_index, u_index),
276 _ => return false,
277 };
278 if v_index == u_index {
279 return false;
280 }
281 let v_entry = self.entries.get(*v_index).unwrap();
282 let u_entry = self.entries.get(*u_index).unwrap();
283 let mut visitor = SearchVisitor::new(*u_index);
285 let mut traverser = DagTraverser::new(Direction::Forward);
286 traverser.push_index(*v_index);
287 if v_entry.order < u_entry.order {
289 traverser.traverse(self, 0..=u_entry.order, &mut visitor);
290 visitor.found
291 } else if self.cycles.is_empty() {
292 false
294 } else {
295 traverser.traverse(self, 0..=u64::MAX, &mut visitor);
296 visitor.found
297 }
298 }
299
300 #[inline(always)]
301 fn update_cycles(&mut self, change: GraphChange) {
302 let cycles = std::mem::take(&mut self.cycles);
308
309 for cycle in cycles {
310 if change.should_remove(&cycle) {
311 continue;
313 }
314
315 let v = cycle.0;
316 let u = cycle.1;
317
318 let mut visitor = SearchVisitor::new(u);
322 let mut traverser = DagTraverser::new(Direction::Forward);
323 traverser.push_index(v);
324 traverser.traverse(self, 0..=u64::MAX, &mut visitor);
325
326 if !visitor.found {
327 continue;
330 }
331
332 self.add_edge_helper(v, u, true);
333 }
334 }
335
336 fn add_edge_helper(&mut self, v_index: Index, u_index: Index, visit_all: bool) {
339 let (v_order, u_order) = {
340 let v_entry = self.entries.get(v_index).unwrap();
341 let u_entry = self.entries.get(u_index).unwrap();
342 (v_entry.order, u_entry.order)
343 };
344
345 let mut traverser = DagTraverser::new(Direction::Forward);
351 let mut visited_forward = CollectVisitor::default();
352 let mut visited_backward = CollectVisitor::default();
353
354 let range = if self.cycles.is_empty() && !visit_all {
355 0..=v_order
356 } else {
357 0..=u64::MAX
358 };
359
360 traverser.push_index(u_index);
363 traverser.traverse(self, range, &mut visited_forward);
364
365 if traverser.has_visited(&v_index) {
366 self.cycles.insert(Cycle(v_index, u_index));
368 } else {
369 traverser.direction = Direction::Backward;
371 traverser.push_index(v_index);
372 traverser.traverse(self, (u_order + 1).., &mut visited_backward);
373 let visited_forward = visited_forward.collected;
374 let visited_backward = visited_backward.collected;
375 self.reorder(visited_forward, visited_backward);
376 }
377 }
378
379 fn reorder(
380 &mut self,
381 mut visited_forward: Vec<(Index, u64)>,
382 mut visited_backward: Vec<(Index, u64)>,
383 ) {
384 visited_forward.sort_by_key(|(_, order)| *order);
386 visited_backward.sort_by_key(|(_, order)| *order);
387
388 let len1 = visited_forward.len();
389 let len2 = visited_backward.len();
390 let mut i1 = 0usize;
391 let mut i2 = 0usize;
392 let mut index_iter = visited_backward.iter().chain(visited_forward.iter());
393
394 while i1 < len1 && i2 < len2 {
395 let (_, o1) = visited_forward[i1];
396 let (_, o2) = visited_backward[i2];
397
398 let index = index_iter.next().unwrap().0;
399 self.entries.get_mut(index).unwrap().order = if o1 < o2 {
400 i1 += 1;
401 o1
402 } else {
403 i2 += 1;
404 o2
405 };
406 }
407
408 while i1 < len1 {
409 let index = index_iter.next().unwrap().0;
410 self.entries.get_mut(index).unwrap().order = visited_forward[i1].1;
411 i1 += 1;
412 }
413
414 while i2 < len2 {
415 let index = index_iter.next().unwrap().0;
416 self.entries.get_mut(index).unwrap().order = visited_backward[i2].1;
417 i2 += 1;
418 }
419 }
420}
421
422impl DagTraverser {
423 pub fn new(direction: Direction) -> Self {
425 Self {
426 direction,
427 stack: VecDeque::new(),
428 visited: FnvHashSet::default(),
429 }
430 }
431
432 #[inline(always)]
434 pub fn has_visited(&self, node: &Index) -> bool {
435 self.visited.contains(node)
436 }
437
438 #[inline(always)]
440 pub fn push_index(&mut self, index: Index) {
441 self.stack.push_front(index);
444 }
445
446 pub fn traverse<N, R: RangeBounds<u64>, V: Visitor>(
448 &mut self,
449 dag: &Dag<N>,
450 _range: R,
451 visitor: &mut V,
452 ) {
453 while let Some(index) = self.stack.pop_back() {
454 let entry = dag.entries.get(index).unwrap();
455
456 if !self.visited.insert(index) {
463 continue;
464 }
465
466 match visitor.visit(&index, entry.order) {
468 ControlFlow::Continue => {}
469 ControlFlow::Stop => break,
470 }
472
473 let to_visit = match self.direction {
475 Direction::Forward => &entry.forward,
476 Direction::Backward => &entry.backward,
477 };
478
479 let mut new_items = 0;
481 for v_index in to_visit {
482 if self.visited.contains(v_index) {
483 continue;
485 }
486
487 new_items += 1;
488 }
489
490 self.stack.reserve(new_items);
492
493 for v_index in to_visit {
495 if self.visited.contains(v_index) {
496 continue;
497 }
498
499 self.stack.push_front(*v_index);
500 }
501 }
502 }
503}
504
505impl SearchVisitor {
506 pub fn new(target: Index) -> Self {
507 SearchVisitor {
508 target,
509 found: false,
510 }
511 }
512}
513
514impl Visitor for SearchVisitor {
515 #[inline(always)]
516 fn visit(&mut self, index: &Index, _order: u64) -> ControlFlow {
517 if self.found || &self.target == index {
518 self.found = true;
519 ControlFlow::Stop
520 } else {
521 ControlFlow::Continue
522 }
523 }
524}
525
526impl Visitor for CollectVisitor {
527 #[inline(always)]
528 fn visit(&mut self, index: &Index, order: u64) -> ControlFlow {
529 self.collected.push((*index, order));
530 ControlFlow::Continue
531 }
532}
533
534impl Visitor for () {
535 #[inline(always)]
536 fn visit(&mut self, _index: &Index, _order: u64) -> ControlFlow {
537 ControlFlow::Continue
538 }
539}
540
541impl<N> Default for Dag<N>
542where
543 N: Hash + Eq,
544{
545 fn default() -> Self {
546 Self::new()
547 }
548}
549
550impl GraphChange {
551 #[inline(always)]
553 pub fn should_remove(&self, cycle: &Cycle) -> bool {
554 match self {
555 Self::DeleteEdge(v, u) => &cycle.0 == v && &cycle.1 == u,
556 Self::DeleteNode(v) => &cycle.0 == v || &cycle.1 == v,
557 }
558 }
559}