interactive_dag/
lib.rs

1use fnv::{FnvHashMap, FnvHashSet};
2use generational_arena::{Arena, Index};
3use std::borrow::Borrow;
4use std::collections::{hash_map, VecDeque};
5use std::hash::Hash;
6use std::ops::RangeBounds;
7
// Compiled only for tests or with the `test-utils` feature; presumably a simple reference
// implementation used to cross-check `Dag` — TODO confirm against naive.rs.
#[cfg(any(feature = "test-utils", test))]
pub mod naive;
10
11/// An incremental directed acyclic graph, however the word 'acyclic' is used, this structure allows
12/// the existence of cycles in the hope that they will eventually be resolved and provides APIs to
13/// report the cycles to the user of the structure.
14pub struct Dag<N> {
15    /// The metadata of each node.
16    entries: Arena<Entry>,
17    /// Match each node to the index of the arena.
18    nodes: FnvHashMap<N, Index>,
19    /// All of the cycles in this graph.
20    cycles: FnvHashSet<Cycle>,
21    /// The first order that is available to use.
22    next_order: u64,
23}
24
25/// An entry in the DAG which contains the metadata about an edge.
26#[derive(Default)]
27struct Entry {
28    forward: FnvHashSet<Index>,
29    backward: FnvHashSet<Index>,
30    order: u64,
31}
32
/// Errors reported by the edge-manipulation methods of [`Dag`].
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
pub enum Error {
    /// A node passed to the operation does not exist in the graph.
    NotFound,
    /// The operation tried to create or remove an edge from a node to itself.
    SelfLoop,
}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Error::NotFound => "node not found in the graph",
            Error::SelfLoop => "an edge cannot connect a node to itself",
        };
        f.write_str(message)
    }
}

// Lets callers use `?` into `Box<dyn std::error::Error>` and similar error sinks.
impl std::error::Error for Error {}
40
41/// A helper structure that keeps track of a visit stack and the set of already visited nodes and
42/// visit the nodes in the graph in the given direction.
43struct DagTraverser {
44    /// The nodes that we should visit.
45    stack: VecDeque<Index>,
46    /// Nodes that we have already visited.
47    visited: FnvHashSet<Index>,
48    /// The direction in which we should be moving.
49    direction: Direction,
50}
51
/// Which edges a [`DagTraverser`] follows from each node it visits.
enum Direction {
    /// Follow outgoing edges, visiting children.
    Forward,
    /// Follow incoming edges, visiting parents.
    Backward,
}
59
/// Tells a traverser what to do after a visitor has seen a node.
#[derive(Clone, Copy, PartialEq, Eq)]
enum ControlFlow {
    /// Abort the traversal immediately.
    Stop,
    /// Keep traversing as normal.
    Continue,
}
69
70/// A visitor that stops when it finds a target node.
71struct SearchVisitor {
72    target: Index,
73    found: bool,
74}
75
76/// A visitor that just collects all of the nodes that it sees.
77#[derive(Default)]
78struct CollectVisitor {
79    collected: Vec<(Index, u64)>,
80}
81
82trait Visitor {
83    fn visit(&mut self, index: &Index, order: u64) -> ControlFlow;
84}
85
86#[derive(Hash, Eq, PartialEq, Debug)]
87struct Cycle(Index, Index);
88
89enum GraphChange {
90    DeleteNode(Index),
91    DeleteEdge(Index, Index),
92}
93
94impl<N> Dag<N>
95where
96    N: Hash + Eq,
97{
98    /// Create a new empty DAG instance.
99    pub fn new() -> Self {
100        Self {
101            entries: Arena::new(),
102            nodes: FnvHashMap::default(),
103            cycles: FnvHashSet::default(),
104            next_order: 0,
105        }
106    }
107
108    /// Insert a new node to the graph. Returns `false` if the node already existed in the graph.
109    pub fn insert(&mut self, node: N) -> bool {
110        if let hash_map::Entry::Vacant(e) = self.nodes.entry(node) {
111            let entry = Entry {
112                order: self.next_order,
113                ..Entry::default()
114            };
115            let index = self.entries.insert(entry);
116            e.insert(index);
117            self.next_order += 1;
118            true
119        } else {
120            false
121        }
122    }
123
124    /// Remove a node from the graph. Returns the node which got removed in an `Option`, if the node
125    /// did not exist in the first place `None` is returned.
126    pub fn remove<Q: ?Sized>(&mut self, node: &Q) -> Option<N>
127    where
128        N: Borrow<Q>,
129        Q: Hash + Eq,
130    {
131        let (node, index) = if let Some((n, index)) = self.nodes.remove_entry(node) {
132            (n, index)
133        } else {
134            return None;
135        };
136
137        let entry = self.entries.remove(index).unwrap();
138
139        for i in entry.forward {
140            let entry = self.entries.get_mut(i).unwrap();
141            entry.backward.remove(&index);
142        }
143
144        for i in entry.backward {
145            let entry = self.entries.get_mut(i).unwrap();
146            entry.forward.remove(&index);
147        }
148
149        if !self.cycles.is_empty() {
150            self.update_cycles(GraphChange::DeleteNode(index));
151        }
152
153        Some(node)
154    }
155
156    //// Insert a new edge to the graph. Returns an error if either one of the nodes is not present
157    /// in the graph, `Ok(false)` is returned if the connection already existed in the graph.
158    ///
159    /// If the same node is passed for both the values of `v` and `u` then `Err(Error::SelfLoop)`
160    /// is returned.
161    ///
162    /// ```
163    /// use interactive_dag::{Dag, Error};
164    /// let mut g = Dag::<u32>::new();
165    /// assert_eq!(g.connect(&0, &0), Err(Error::NotFound));
166    /// g.insert(0);
167    /// assert_eq!(g.connect(&0, &0), Err(Error::SelfLoop));
168    /// g.insert(1);
169    /// assert_eq!(g.connect(&0, &1), Ok(true));
170    /// assert_eq!(g.connect(&0, &1), Ok(false));
171    /// ```
172    pub fn connect<Q: ?Sized>(&mut self, v: &Q, u: &Q) -> Result<bool, Error>
173    where
174        N: Borrow<Q>,
175        Q: Hash + Eq,
176    {
177        let v_index = *self.nodes.get(v).ok_or(Error::NotFound)?;
178        let u_index = *self.nodes.get(u).ok_or(Error::NotFound)?;
179
180        // self loops are not allowed.
181        if v_index == u_index {
182            return Err(Error::SelfLoop);
183        }
184
185        if self
186            .entries
187            .get(v_index)
188            .unwrap()
189            .forward
190            .contains(&u_index)
191        {
192            // The connection already exits.
193            return Ok(false);
194        }
195
196        self.add_edge_helper(v_index, u_index, false);
197
198        // Perform the insertion.
199        let tmp = self.entries.get2_mut(v_index, u_index);
200        let v_entry = tmp.0.unwrap();
201        let u_entry = tmp.1.unwrap();
202        v_entry.forward.insert(u_index);
203        u_entry.backward.insert(v_index);
204
205        Ok(true)
206    }
207
208    /// Removes an edge from the graph.
209    pub fn disconnect<Q: ?Sized>(&mut self, v: &Q, u: &Q) -> Result<bool, Error>
210    where
211        N: Borrow<Q>,
212        Q: Hash + Eq,
213    {
214        let v_index = self.nodes.get(v).ok_or(Error::NotFound)?;
215        let u_index = self.nodes.get(u).ok_or(Error::NotFound)?;
216
217        if v_index == u_index {
218            return Err(Error::SelfLoop);
219        }
220
221        {
222            let tmp = self.entries.get2_mut(*v_index, *u_index);
223            let v_entry = tmp.0.unwrap();
224            let u_entry = tmp.1.unwrap();
225
226            if !v_entry.forward.remove(u_index) {
227                // the connection does not even exists.
228                return Ok(false);
229            }
230
231            u_entry.backward.remove(v_index);
232        }
233
234        if !self.cycles.is_empty() {
235            self.update_cycles(GraphChange::DeleteEdge(*v_index, *u_index));
236        }
237
238        Ok(true)
239    }
240
241    /// Returns `true` if the graph contains the given node.
242    pub fn contains<Q: ?Sized>(&self, v: &Q) -> bool
243    where
244        N: Borrow<Q>,
245        Q: Hash + Eq,
246    {
247        self.nodes.contains_key(v)
248    }
249
250    /// Returns `true` if there is a direct edge connection `v` to `u`.
251    pub fn is_connected<Q: ?Sized>(&self, v: &Q, u: &Q) -> bool
252    where
253        N: Borrow<Q>,
254        Q: Hash + Eq,
255    {
256        let (v_index, u_index) = match (self.nodes.get(v), self.nodes.get(u)) {
257            (Some(v_index), Some(u_index)) => (v_index, u_index),
258            _ => return false,
259        };
260        // check the existence of the edge backward.
261        self.entries
262            .get(*u_index)
263            .unwrap()
264            .backward
265            .contains(v_index)
266    }
267
268    /// Returns `true` if there is a path from `v` to `u`.
269    pub fn is_reachable<Q: ?Sized>(&self, v: &Q, u: &Q) -> bool
270    where
271        N: Borrow<Q>,
272        Q: Hash + Eq,
273    {
274        let (v_index, u_index) = match (self.nodes.get(v), self.nodes.get(u)) {
275            (Some(v_index), Some(u_index)) => (v_index, u_index),
276            _ => return false,
277        };
278        if v_index == u_index {
279            return false;
280        }
281        let v_entry = self.entries.get(*v_index).unwrap();
282        let u_entry = self.entries.get(*u_index).unwrap();
283        // construct the traverser that searches for `u`.
284        let mut visitor = SearchVisitor::new(*u_index);
285        let mut traverser = DagTraverser::new(Direction::Forward);
286        traverser.push_index(*v_index);
287        // return the result based on the searches.
288        if v_entry.order < u_entry.order {
289            traverser.traverse(self, 0..=u_entry.order, &mut visitor);
290            visitor.found
291        } else if self.cycles.is_empty() {
292            // If there is no cycle in this graph, then there is not gonna be a path from v->u.
293            false
294        } else {
295            traverser.traverse(self, 0..=u64::MAX, &mut visitor);
296            visitor.found
297        }
298    }
299
300    #[inline(always)]
301    fn update_cycles(&mut self, change: GraphChange) {
302        // The strategy is simple:
303        // 1. Remove every cycle that is immediately effected by this change.
304        // 2. Iterate over every cycle and check if they are resolved.
305        // 3. If so, remove the cycle and attend inserting it again to reorder it.
306
307        let cycles = std::mem::take(&mut self.cycles);
308
309        for cycle in cycles {
310            if change.should_remove(&cycle) {
311                // the cycle should immediately be removed. There is nothing else to do.
312                continue;
313            }
314
315            let v = cycle.0;
316            let u = cycle.1;
317
318            // check for the existence of v -> u.
319            // the add_edge_helper function checks the existence of u -> v, and if it exits then
320            // it will insert a new cycle to self.cycles.
321            let mut visitor = SearchVisitor::new(u);
322            let mut traverser = DagTraverser::new(Direction::Forward);
323            traverser.push_index(v);
324            traverser.traverse(self, 0..=u64::MAX, &mut visitor);
325
326            if !visitor.found {
327                // The cycle is resolved so we can move on. But the edge is not removed so this
328                // means we have to do something to reorder the graph.
329                continue;
330            }
331
332            self.add_edge_helper(v, u, true);
333        }
334    }
335
336    /// Performs the necessary reordering upon insertion of a new edge, it is also called from
337    /// update_cycles for when a cycle is resolved.
338    fn add_edge_helper(&mut self, v_index: Index, u_index: Index, visit_all: bool) {
339        let (v_order, u_order) = {
340            let v_entry = self.entries.get(v_index).unwrap();
341            let u_entry = self.entries.get(u_index).unwrap();
342            (v_entry.order, u_entry.order)
343        };
344
345        // If we're already sorted, don't do anything.
346        // if v_order <= u_order {
347        //     return;
348        // }
349
350        let mut traverser = DagTraverser::new(Direction::Forward);
351        let mut visited_forward = CollectVisitor::default();
352        let mut visited_backward = CollectVisitor::default();
353
354        let range = if self.cycles.is_empty() && !visit_all {
355            0..=v_order
356        } else {
357            0..=u64::MAX
358        };
359
360        // Start from `u` and move forward, here we want to see if there is a path from u -> v, and
361        // in that case we have found a cycle.
362        traverser.push_index(u_index);
363        traverser.traverse(self, range, &mut visited_forward);
364
365        if traverser.has_visited(&v_index) {
366            // We have found a cycle. So we should report it and keep track of it.
367            self.cycles.insert(Cycle(v_index, u_index));
368        } else {
369            // Reorder the graph to maintain the topological ordering.
370            traverser.direction = Direction::Backward;
371            traverser.push_index(v_index);
372            traverser.traverse(self, (u_order + 1).., &mut visited_backward);
373            let visited_forward = visited_forward.collected;
374            let visited_backward = visited_backward.collected;
375            self.reorder(visited_forward, visited_backward);
376        }
377    }
378
379    fn reorder(
380        &mut self,
381        mut visited_forward: Vec<(Index, u64)>,
382        mut visited_backward: Vec<(Index, u64)>,
383    ) {
384        // sort the nodes by their original order.
385        visited_forward.sort_by_key(|(_, order)| *order);
386        visited_backward.sort_by_key(|(_, order)| *order);
387
388        let len1 = visited_forward.len();
389        let len2 = visited_backward.len();
390        let mut i1 = 0usize;
391        let mut i2 = 0usize;
392        let mut index_iter = visited_backward.iter().chain(visited_forward.iter());
393
394        while i1 < len1 && i2 < len2 {
395            let (_, o1) = visited_forward[i1];
396            let (_, o2) = visited_backward[i2];
397
398            let index = index_iter.next().unwrap().0;
399            self.entries.get_mut(index).unwrap().order = if o1 < o2 {
400                i1 += 1;
401                o1
402            } else {
403                i2 += 1;
404                o2
405            };
406        }
407
408        while i1 < len1 {
409            let index = index_iter.next().unwrap().0;
410            self.entries.get_mut(index).unwrap().order = visited_forward[i1].1;
411            i1 += 1;
412        }
413
414        while i2 < len2 {
415            let index = index_iter.next().unwrap().0;
416            self.entries.get_mut(index).unwrap().order = visited_backward[i2].1;
417            i2 += 1;
418        }
419    }
420}
421
422impl DagTraverser {
423    /// Create a new traverser that moves in the given direction.
424    pub fn new(direction: Direction) -> Self {
425        Self {
426            direction,
427            stack: VecDeque::new(),
428            visited: FnvHashSet::default(),
429        }
430    }
431
432    /// Returns true if the given element is visited.
433    #[inline(always)]
434    pub fn has_visited(&self, node: &Index) -> bool {
435        self.visited.contains(node)
436    }
437
438    /// Push a new node to the traverser stack so it can be visited.
439    #[inline(always)]
440    pub fn push_index(&mut self, index: Index) {
441        // do not check if the node is visited since this can be a strategy when the controller
442        // wants to traverse the graph forward and backward.
443        self.stack.push_front(index);
444    }
445
446    /// Start visiting the provided graph at the given range with the provided visitor.
447    pub fn traverse<N, R: RangeBounds<u64>, V: Visitor>(
448        &mut self,
449        dag: &Dag<N>,
450        _range: R,
451        visitor: &mut V,
452    ) {
453        while let Some(index) = self.stack.pop_back() {
454            let entry = dag.entries.get(index).unwrap();
455
456            // check if the node is within the range of our traversal.
457            // if !range.contains(&entry.order) {
458            //     continue;
459            // }
460
461            // only visit once.
462            if !self.visited.insert(index) {
463                continue;
464            }
465
466            // pass it to the visitor.
467            match visitor.visit(&index, entry.order) {
468                ControlFlow::Continue => {}
469                ControlFlow::Stop => break,
470                // ControlFlow::SkipChildren => continue,
471            }
472
473            // get the next nodes we have to visit depending on the direction.
474            let to_visit = match self.direction {
475                Direction::Forward => &entry.forward,
476                Direction::Backward => &entry.backward,
477            };
478
479            // count the new items we need to append to the stack.
480            let mut new_items = 0;
481            for v_index in to_visit {
482                if self.visited.contains(v_index) {
483                    // the node is already visited - ignore it.
484                    continue;
485                }
486
487                new_items += 1;
488            }
489
490            // insert the items to the stack so we can visit them later.
491            self.stack.reserve(new_items);
492
493            // perform the insertion of items to the stack.
494            for v_index in to_visit {
495                if self.visited.contains(v_index) {
496                    continue;
497                }
498
499                self.stack.push_front(*v_index);
500            }
501        }
502    }
503}
504
505impl SearchVisitor {
506    pub fn new(target: Index) -> Self {
507        SearchVisitor {
508            target,
509            found: false,
510        }
511    }
512}
513
514impl Visitor for SearchVisitor {
515    #[inline(always)]
516    fn visit(&mut self, index: &Index, _order: u64) -> ControlFlow {
517        if self.found || &self.target == index {
518            self.found = true;
519            ControlFlow::Stop
520        } else {
521            ControlFlow::Continue
522        }
523    }
524}
525
526impl Visitor for CollectVisitor {
527    #[inline(always)]
528    fn visit(&mut self, index: &Index, order: u64) -> ControlFlow {
529        self.collected.push((*index, order));
530        ControlFlow::Continue
531    }
532}
533
534impl Visitor for () {
535    #[inline(always)]
536    fn visit(&mut self, _index: &Index, _order: u64) -> ControlFlow {
537        ControlFlow::Continue
538    }
539}
540
541impl<N> Default for Dag<N>
542where
543    N: Hash + Eq,
544{
545    fn default() -> Self {
546        Self::new()
547    }
548}
549
550impl GraphChange {
551    /// Returns `true` if the cycle should be removed based on this change.
552    #[inline(always)]
553    pub fn should_remove(&self, cycle: &Cycle) -> bool {
554        match self {
555            Self::DeleteEdge(v, u) => &cycle.0 == v && &cycle.1 == u,
556            Self::DeleteNode(v) => &cycle.0 == v || &cycle.1 == v,
557        }
558    }
559}