Skip to main content

cedar_policy_core/
transitive_closure.rs

1/*
2 * Copyright Cedar Contributors
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      https://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17//! Module containing code to compute the transitive closure of a graph.
18//! This is a generic utility, and not specific to Cedar.
19
20use std::collections::{HashMap, HashSet};
21use std::fmt::{Debug, Display};
22use std::hash::Hash;
23
24mod err;
25pub use err::*;
26use itertools::Itertools;
27
/// Abstraction over a graph node whose outgoing edges can be saturated to form
/// a transitive closure.
///
/// Implement this for the node type of a hierarchy (the entity hierarchy, for
/// instance) when only the direct adjacencies are known up front and the
/// transitive closure has to be computed from them. The type parameter `K` is
/// the type that uniquely identifies a node in the graph.
pub trait TCNode<K> {
    /// Returns the unique identifier of this node.
    fn get_key(&self) -> K;

    /// Records an outgoing edge from this node to the node identified by `k`.
    fn add_edge_to(&mut self, k: K);

    /// Returns an iterator over all outgoing edges of this node.
    fn out_edges(&self) -> Box<dyn Iterator<Item = &K> + '_>;

    /// Returns `true` if there is an edge from this node to the node
    /// identified by `k`.
    fn has_edge_to(&self, k: &K) -> bool;

    /// Resets this node's edges back to their base set (what "base" means is
    /// up to the implementor).
    fn reset_edges(&mut self);

    /// Returns an iterator over the direct outgoing edges of this node.
    ///
    /// By default this yields exactly the same edges as
    /// [`TCNode::out_edges`]; implementors that track direct edges separately
    /// may override it.
    fn direct_edges(&self) -> Box<dyn Iterator<Item = &K> + '_> {
        self.out_edges()
    }
}
55
56/// Given Graph as a map from keys with type `K` to implementations of `TCNode`
57/// with type `V`, compute the transitive closure of the hierarchy. In case of
58/// error, the result contains an error structure `Err<K>` which contains the
59/// keys (with type `K`) for the nodes in the graph which caused the error.
60/// If `enforce_dag` then also check that the hierarchy is a DAG
61pub fn compute_tc<K, V>(nodes: &mut HashMap<K, V>, enforce_dag: bool) -> Result<(), K>
62where
63    K: Clone + Eq + Hash + Debug + Display,
64    V: TCNode<K>,
65{
66    // Always use the SCC-based algorithm which correctly handles cycles.
67    // The single-pass algorithm is only correct for DAGs and produces
68    // incomplete results when the graph contains cycles of length >= 3.
69    cyclic_tc(nodes);
70    if enforce_dag {
71        return enforce_dag_from_tc(nodes);
72    }
73    Ok(())
74}
75
76/// Given Graph as a map from keys with type `K` to implementations of `TCNode`
77/// with type `V`, repair the transitive closure of the hierarchy. The below code
78/// will assume that for each `node` in `nodes` except the nodes appearing in
79/// `nodes_to_fix`, the out-going edges of `node` will contain all ancestors of `node`.
80/// That is we may assume the transitive closure for all such nodes is correct while
81/// computing the transitive closure of each node appearing in `nodes_to_fix`.
82/// In case of error, the result contains an error structure `Err<K>` which contains
83/// the keys (with type `K`) for the nodes in the graph which caused the error.
84/// If `enforce_dag` then also check that the heirarchy is a DAG
85pub fn repair_tc<K, V>(
86    nodes_to_fix: HashSet<K>,
87    nodes: &mut HashMap<K, V>,
88    enforce_dag: bool,
89) -> Result<(), K>
90where
91    K: Clone + Eq + Hash + Debug + Display,
92    V: TCNode<K>,
93{
94    let seen: HashSet<K> = nodes
95        .keys()
96        .filter(|x| !nodes_to_fix.contains(x))
97        .cloned()
98        .collect();
99
100    // If the caller does not want to check that the graph is a DAG,
101    // we assume that the graph is acyclic during the below call.
102    // This allows the below call to do a single scan of each node
103    // rather than two scans of each node.
104    compute_tc_internal::<K, V>(nodes_to_fix.iter().cloned(), nodes, seen, enforce_dag);
105
106    if enforce_dag {
107        // TC is correct by construction (same as compute_tc). Only need to
108        // check for self-loops, and only on nodes whose edges were modified.
109        return enforce_dag_from_tc_for(&nodes_to_fix, nodes);
110    }
111    Ok(())
112}
113
114/// Saturate the out-going edges of each node in `node_ids` to include
115/// all reachable ancestors within the graph represnted by `nodes`.
116/// Assume that all nodes appearing in `seen` already satisfy this property.
117/// If `detect_cyles` is false, we assume the the graph represented by `nodes`
118/// is a DAG so that we may perform a single scan over the graph. Otherwise,
119/// we scan each node twice. This is sufficient for detecting cycles and for computing
120/// the exact TC for graphs containing simple cycles. For more complex cyclic graphs,
121/// the below code computes enough of the transtive closure to ensure that if one
122/// calls `enforce_dag_from_tc` on `nodes` after this function returns then it will
123/// correctly detect any cycles (simple or complex).
124fn compute_tc_internal<K, V>(
125    node_ids: impl Iterator<Item = K>,
126    nodes: &mut HashMap<K, V>,
127    mut seen: HashSet<K>,
128    detect_cyles: bool,
129) where
130    K: Clone + Eq + Hash,
131    V: TCNode<K>,
132{
133    for node_id in node_ids {
134        // When detect_cycles is false, skip nodes already in `seen` (single visit).
135        // When detect_cycles is true, always visit — the second pass propagates
136        // enough TC edges for self-loop detection on cyclic graphs.
137        if detect_cyles || seen.insert(node_id.clone()) {
138            add_ancestors(&node_id, nodes, &mut seen);
139        }
140    }
141}
142
143fn cyclic_tc<K, V>(nodes: &mut HashMap<K, V>)
144where
145    K: Clone + Eq + Hash + Debug,
146    V: TCNode<K>,
147{
148    let node_ids = nodes.keys().map(K::clone).collect::<Vec<K>>();
149    let mut order_visited = HashMap::new();
150    let mut root = HashMap::new();
151    let mut vstack = Vec::new();
152    let mut cstack = Vec::new();
153    let mut component = HashMap::new();
154    let mut comp_succ = Vec::new();
155    let mut comp_elts = Vec::new();
156    for node_id in node_ids {
157        if !order_visited.contains_key(&node_id) {
158            cyclic_tc_internal(
159                &node_id,
160                nodes,
161                &mut order_visited,
162                &mut root,
163                &mut vstack,
164                &mut cstack,
165                &mut component,
166                &mut comp_succ,
167                &mut comp_elts,
168            );
169        }
170    }
171    // component_tc => nodes_tc
172    for comp_id in 0..comp_elts.len() {
173        let mut elt_succ = HashSet::new();
174        #[expect(
175            clippy::indexing_slicing,
176            reason = "`comp_succ` and `comp_elts` must have the same length, thus `comp_id` is a valid index to `comp_succ`."
177        )]
178        for comp_parent_id in comp_succ[comp_id].iter() {
179            #[expect(
180                clippy::indexing_slicing,
181                reason = "`comp_parent_id` must be a valid component id to be inserted into `comp_succ` therefore must exist within `comp_elts`."
182            )]
183            for node_id in comp_elts[*comp_parent_id].iter() {
184                // not fine to consume here
185                elt_succ.insert(node_id.clone());
186            }
187        }
188
189        #[expect(
190            clippy::indexing_slicing,
191            reason = "`comp_id` in [0, |`comp_elts`|) is a valid index into `comp_elts`."
192        )]
193        for node_id in comp_elts[comp_id].iter() {
194            let node = match nodes.get_mut(node_id) {
195                Some(node) => node,
196                None => continue,
197            };
198            for parent_id in elt_succ.iter() {
199                node.add_edge_to(parent_id.clone());
200            }
201        }
202    }
203}
204
205#[expect(
206    clippy::too_many_arguments,
207    reason = "internal function in complex algorithm"
208)]
209fn cyclic_tc_internal<K, V>(
210    node_id: &K,
211    nodes: &HashMap<K, V>,
212    order_visited: &mut HashMap<K, usize>,
213    root: &mut HashMap<K, K>,
214    vstack: &mut Vec<K>,
215    cstack: &mut Vec<usize>,
216    component: &mut HashMap<K, usize>,
217    comp_succ: &mut Vec<HashSet<usize>>,
218    comp_elts: &mut Vec<HashSet<K>>,
219) where
220    K: Clone + Eq + Hash + Debug,
221    V: TCNode<K>,
222{
223    let node_order = order_visited.len();
224    // when was the root of this node's component visited?
225    // initially the root of this node's component is this node itself
226    // keeping track in auxillary function to avoid re-fetching in a loop
227    let mut root_order = node_order;
228    order_visited.insert(node_id.clone(), node_order);
229    root.insert(node_id.clone(), node_id.clone());
230    vstack.push(node_id.clone());
231    let height = cstack.len();
232    let mut self_loop = false;
233    let out_edges = match nodes.get(node_id) {
234        Some(node) => node.out_edges().collect(),
235        None => Vec::new(),
236    };
237    for parent_id in out_edges {
238        if node_id == parent_id {
239            self_loop = true;
240        } else {
241            // The edge from node_id to parent_id is a forward edge iff
242            // node_id is visited before parent_id and we do not visit
243            // parent_id from node_id (i.e., we do not recursively call
244            // cyclic_tc_internal on parent_id from this call).
245            let mut maybe_forward_edge = true;
246            if !order_visited.contains_key(parent_id) {
247                maybe_forward_edge = false;
248                cyclic_tc_internal(
249                    parent_id,
250                    nodes,
251                    order_visited,
252                    root,
253                    vstack,
254                    cstack,
255                    component,
256                    comp_succ,
257                    comp_elts,
258                );
259            }
260            match component.get(parent_id) {
261                None => {
262                    #[expect(
263                        clippy::expect_used,
264                        reason = "`parent_id` must have been visited either by a previous call or just above"
265                    )]
266                    let parent_root = root
267                        .get(parent_id)
268                        .expect("Parent has been visited so it must have a root.");
269                    #[expect(
270                        clippy::expect_used,
271                        reason = "in order for `parent_root` to be the parent of `parent_id` it must have been visited."
272                    )]
273                    let parent_root_order = order_visited
274                        .get(parent_root)
275                        .expect("The parent's root must have been visited.");
276                    if *parent_root_order < root_order {
277                        root_order = *parent_root_order;
278                        root.insert(node_id.clone(), parent_root.clone());
279                    }
280                }
281                Some(parent_component) => {
282                    #[expect(
283                        clippy::expect_used,
284                        reason = "`parent_id` must have been visited either by a previous call or just above"
285                    )]
286                    let parent_order = order_visited
287                        .get(parent_id)
288                        .expect("The parent must have been traversed by this point.");
289                    // if not a forward edge
290                    if !(maybe_forward_edge && &node_order < parent_order) {
291                        cstack.push(*parent_component);
292                    }
293                }
294            }
295        }
296    } // end for loop over parents
297    #[expect(
298        clippy::expect_used,
299        reason = "`node_id` must have a root. It was inserted at the begining of this function"
300    )]
301    let node_root = root
302        .get(node_id)
303        .expect("Node must have a root by this point.");
304    // if this node is the root of its connected component
305    if node_id == node_root {
306        let component_id = comp_elts.len();
307        let mut succ = HashSet::new();
308        let mut elmts = HashSet::new();
309        #[expect(
310            clippy::expect_used,
311            reason = " The vertex stack must not be empty because at least node_id must be on the stack!"
312        )]
313        if self_loop || vstack.last().expect("vertex stack must be non-empty") != node_id {
314            succ.insert(component_id);
315        }
316        let mut cstack_tail = cstack.split_off(height);
317        // sort by topological order of the components, which should be equivalent to the reverse order of their ids
318        // cstack_tail are all of the components reachable (1 step) from any node within this component
319        cstack_tail.sort_by(|a, b| b.cmp(a));
320        // iterate through components in topological order
321        for i in 0..cstack_tail.len() {
322            // update this component's successors with next component avoiding duplicate components
323            #[expect(
324                clippy::indexing_slicing,
325                reason = "both `i` and `i - 1` are guaranteed to be valid indices into `cstack_tail`."
326            )]
327            if i == 0 || cstack_tail[i - 1] != cstack_tail[i] {
328                #[expect(
329                    clippy::indexing_slicing,
330                    reason = "`i` is a valid index into `cstack_tail`."
331                )]
332                let tail_elt = cstack_tail[i];
333                if succ.insert(tail_elt) {
334                    #[expect(
335                        clippy::indexing_slicing,
336                        reason = "`tail_elt` is a component id created by a previous call to `cyclic_tc_internal` and thus must be a valid index to `comp_succ`."
337                    )]
338                    succ.extend(comp_succ[tail_elt].clone());
339                }
340            }
341        }
342        loop {
343            #[expect(
344                clippy::expect_used,
345                reason = "The vertex stack `vstack` must contain at least `node_id`"
346            )]
347            let ancestor_id = vstack.pop().expect("Vetex stack must be non-empty");
348            component.insert(ancestor_id.clone(), component_id);
349            elmts.insert(ancestor_id.clone());
350            if *node_id == ancestor_id {
351                break;
352            }
353        }
354        comp_succ.push(succ);
355        comp_elts.push(elmts);
356    }
357}
358
359/// Saturate the out-going edges of the node identified by `node_id` within the graph
360/// represented by `nodes` assuming that each node appearing in `seen` already satisfies
361/// this property. The process works by performing a depth-first search over the ancestors
362/// of `node_id` (and stopping if any ancestor is already in the `seen` set).
363fn add_ancestors<K, V>(node_id: &K, nodes: &mut HashMap<K, V>, seen: &mut HashSet<K>)
364where
365    K: Clone + Eq + Hash,
366    V: TCNode<K>,
367{
368    let mut ancestors: HashSet<K> = HashSet::new();
369    // Track which ancestors we have already read out_edges from, to avoid
370    // redundant work. This is distinct from `ancestors` because an ancestor_id
371    // may appear in `ancestors` (added via some other node's out_edges) before
372    // we have actually read *its* out_edges.
373    let mut explored: HashSet<K> = HashSet::new();
374    let out_edges: Vec<K> = match nodes.get(node_id) {
375        Some(node) => node.out_edges().map(K::clone).collect(),
376        None => return,
377    };
378    for ancestor_id in out_edges {
379        if seen.insert(ancestor_id.clone()) {
380            add_ancestors(&ancestor_id, nodes, seen);
381        }
382        // Only skip reading out_edges if we already explored this exact
383        // ancestor in a previous iteration of this loop.
384        if explored.insert(ancestor_id.clone()) {
385            let ancestor = match nodes.get(&ancestor_id) {
386                Some(ancestor) => ancestor,
387                None => continue,
388            };
389            for grand_ancestor_id in ancestor.out_edges() {
390                ancestors.insert(grand_ancestor_id.clone());
391            }
392        }
393    }
394    #[expect(
395        clippy::expect_used,
396        reason = "this node should always exist because of the check to get `out_edges`"
397    )]
398    let node = nodes
399        .get_mut(node_id)
400        .expect("This node should always exist.");
401    // Do the actual saturation of out-going edges of `node` here to avoid
402    // issues with rust's borrow checker.
403    for ancestor_id in ancestors {
404        node.add_edge_to(ancestor_id);
405    }
406}
407
408/// Given a graph (as a map from keys to `TCNode`), enforce that
409/// all transitive edges are included, ie, the transitive closure has already
410/// been computed and that it is a DAG. If this is not the case, return an appropriate
411/// `TCEnforcementError`.
412pub fn enforce_tc_and_dag<K, V>(entities: &HashMap<K, V>) -> Result<(), K>
413where
414    K: Clone + Eq + Hash + Debug + Display,
415    V: TCNode<K>,
416{
417    let res = enforce_tc(entities);
418    if res.is_ok() {
419        return enforce_dag_from_tc(entities);
420    }
421    res
422}
423
424/// Given a DAG (as a map from keys to `TCNode`), enforce that
425/// all transitive edges are included, i.e., the transitive closure has already
426/// been computed. If this is not the case, return an appropriate
427/// `MissingTcEdge` error.
428fn enforce_tc<K, V>(entities: &HashMap<K, V>) -> Result<(), K>
429where
430    K: Clone + Eq + Hash + Debug + Display,
431    V: TCNode<K>,
432{
433    for entity in entities.values() {
434        for parent_uid in entity.out_edges() {
435            // check that `entity` is also a child of all of this parent's parents
436            if let Some(parent) = entities.get(parent_uid) {
437                for grandparent in parent.out_edges() {
438                    if !entity.has_edge_to(grandparent) {
439                        return Err(TcError::missing_tc_edge(
440                            entity.get_key(),
441                            parent_uid.clone(),
442                            grandparent.clone(),
443                        ));
444                    }
445                }
446            }
447        }
448    }
449    Ok(())
450}
451
452/// Once the transitive closure (as defined above) is computed/enforced for the graph, we have:
453/// \forall u,v,w \in Vertices . (u,v) \in Edges /\ (v,w) \in Edges -> (u,w) \in Edges
454///
455/// Then the graph has a cycle if
456/// \exists v \in Vertices. (v,v) \in Edges
457fn enforce_dag_from_tc<K, V>(entities: &HashMap<K, V>) -> Result<(), K>
458where
459    K: Clone + Eq + Hash + Debug + Display,
460    V: TCNode<K>,
461{
462    for entity in entities.values() {
463        let key = entity.get_key();
464        if entity.out_edges().contains(&key) {
465            return Err(TcError::has_cycle(key));
466        }
467    }
468    Ok(())
469}
470
471/// Like [`enforce_dag_from_tc`] but only checks nodes in `nodes_to_check`.
472/// This is sound after `repair_tc` because only nodes whose edges were modified
473/// could have gained a self-loop.
474fn enforce_dag_from_tc_for<K, V>(
475    nodes_to_check: &HashSet<K>,
476    entities: &HashMap<K, V>,
477) -> Result<(), K>
478where
479    K: Clone + Eq + Hash + Debug + Display,
480    V: TCNode<K>,
481{
482    for key in nodes_to_check {
483        if let Some(entity) = entities.get(key) {
484            if entity.has_edge_to(key) {
485                return Err(TcError::has_cycle(key.clone()));
486            }
487        }
488    }
489    Ok(())
490}
491
492#[cfg(test)]
493#[expect(clippy::panic, clippy::indexing_slicing, reason = "test code")]
494mod tests {
495    use crate::ast::{Entity, EntityUID};
496
497    use super::*;
498
499    #[test]
500    fn basic() {
501        // start with A -> B -> C
502        let mut a = Entity::with_uid(EntityUID::with_eid("A"));
503        a.add_parent(EntityUID::with_eid("B"));
504        let mut b = Entity::with_uid(EntityUID::with_eid("B"));
505        b.add_parent(EntityUID::with_eid("C"));
506        let c = Entity::with_uid(EntityUID::with_eid("C"));
507        let mut entities = HashMap::from([
508            (a.uid().clone(), a),
509            (b.uid().clone(), b),
510            (c.uid().clone(), c),
511        ]);
512        // currently doesn't pass TC enforcement
513        assert!(enforce_tc(&entities).is_err());
514        // compute TC
515        compute_tc(&mut entities, false).expect("Failed to compute transitive closure");
516        let a = &entities[&EntityUID::with_eid("A")];
517        let b = &entities[&EntityUID::with_eid("B")];
518        let c = &entities[&EntityUID::with_eid("C")];
519        // should have added the A -> C edge
520        assert!(a.is_descendant_of(&EntityUID::with_eid("C")));
521        // but we shouldn't have added other edges, like B -> A or C -> A
522        assert!(!b.is_descendant_of(&EntityUID::with_eid("A")));
523        assert!(!c.is_descendant_of(&EntityUID::with_eid("A")));
524        // now it should pass TC enforcement
525        assert!(enforce_tc(&entities).is_ok());
526        // passes cycle check after TC enforcement
527        assert!(enforce_dag_from_tc(&entities).is_ok());
528    }
529
530    #[test]
531    fn reversed() {
532        // same as basic(), but we put the entities in the map in the reverse
533        // order, which shouldn't make a difference
534        let mut a = Entity::with_uid(EntityUID::with_eid("A"));
535        a.add_parent(EntityUID::with_eid("B"));
536        let mut b = Entity::with_uid(EntityUID::with_eid("B"));
537        b.add_parent(EntityUID::with_eid("C"));
538        let c = Entity::with_uid(EntityUID::with_eid("C"));
539        let mut entities = HashMap::from([
540            (c.uid().clone(), c),
541            (b.uid().clone(), b),
542            (a.uid().clone(), a),
543        ]);
544        // currently doesn't pass TC enforcement
545        assert!(enforce_tc(&entities).is_err());
546        // compute TC
547        compute_tc(&mut entities, false).expect("Failed to compute transitive closure");
548        let a = &entities[&EntityUID::with_eid("A")];
549        let b = &entities[&EntityUID::with_eid("B")];
550        let c = &entities[&EntityUID::with_eid("C")];
551        // should have added the A -> C edge
552        assert!(a.is_descendant_of(&EntityUID::with_eid("C")));
553        // but we shouldn't have added other edges, like B -> A or C -> A
554        assert!(!b.is_descendant_of(&EntityUID::with_eid("A")));
555        assert!(!c.is_descendant_of(&EntityUID::with_eid("A")));
556        // now it should pass TC enforcement
557        assert!(enforce_tc(&entities).is_ok());
558        // passes cycle check after TC enforcement
559        assert!(enforce_dag_from_tc(&entities).is_ok());
560    }
561
562    #[test]
563    fn deeper() {
564        // start with A -> B -> C -> D -> E
565        let mut a = Entity::with_uid(EntityUID::with_eid("A"));
566        a.add_parent(EntityUID::with_eid("B"));
567        let mut b = Entity::with_uid(EntityUID::with_eid("B"));
568        b.add_parent(EntityUID::with_eid("C"));
569        let mut c = Entity::with_uid(EntityUID::with_eid("C"));
570        c.add_parent(EntityUID::with_eid("D"));
571        let mut d = Entity::with_uid(EntityUID::with_eid("D"));
572        d.add_parent(EntityUID::with_eid("E"));
573        let e = Entity::with_uid(EntityUID::with_eid("E"));
574        let mut entities = HashMap::from([
575            (a.uid().clone(), a),
576            (b.uid().clone(), b),
577            (c.uid().clone(), c),
578            (d.uid().clone(), d),
579            (e.uid().clone(), e),
580        ]);
581        // currently doesn't pass TC enforcement
582        assert!(enforce_tc(&entities).is_err());
583        // compute TC
584        compute_tc(&mut entities, false).expect("Failed to compute transitive closure");
585        let a = &entities[&EntityUID::with_eid("A")];
586        let b = &entities[&EntityUID::with_eid("B")];
587        let c = &entities[&EntityUID::with_eid("C")];
588        // should have added many edges which we check for
589        assert!(a.is_descendant_of(&EntityUID::with_eid("C")));
590        assert!(a.is_descendant_of(&EntityUID::with_eid("D")));
591        assert!(a.is_descendant_of(&EntityUID::with_eid("E")));
592        assert!(b.is_descendant_of(&EntityUID::with_eid("D")));
593        assert!(b.is_descendant_of(&EntityUID::with_eid("E")));
594        assert!(c.is_descendant_of(&EntityUID::with_eid("E")));
595        // now it should pass TC enforcement
596        assert!(enforce_tc(&entities).is_ok());
597        // passes cycle check after TC enforcement
598        assert!(enforce_dag_from_tc(&entities).is_ok());
599    }
600
601    #[test]
602    fn not_alphabetized() {
603        // same as deeper(), but the entities' parent relations don't follow
604        // alphabetical order. (In case we end up iterating through the map
605        // in alphabetical order, this test will ensure that everything works
606        // even when we aren't processing the entities in hierarchy order.)
607        // start with foo -> bar -> baz -> ham -> eggs
608        let mut foo = Entity::with_uid(EntityUID::with_eid("foo"));
609        foo.add_parent(EntityUID::with_eid("bar"));
610        let mut bar = Entity::with_uid(EntityUID::with_eid("bar"));
611        bar.add_parent(EntityUID::with_eid("baz"));
612        let mut baz = Entity::with_uid(EntityUID::with_eid("baz"));
613        baz.add_parent(EntityUID::with_eid("ham"));
614        let mut ham = Entity::with_uid(EntityUID::with_eid("ham"));
615        ham.add_parent(EntityUID::with_eid("eggs"));
616        let eggs = Entity::with_uid(EntityUID::with_eid("eggs"));
617        let mut entities = HashMap::from([
618            (ham.uid().clone(), ham),
619            (bar.uid().clone(), bar),
620            (foo.uid().clone(), foo),
621            (eggs.uid().clone(), eggs),
622            (baz.uid().clone(), baz),
623        ]);
624        // currently doesn't pass TC enforcement
625        assert!(enforce_tc(&entities).is_err());
626        // compute TC
627        compute_tc(&mut entities, false).expect("Failed to compute transitive closure");
628        let foo = &entities[&EntityUID::with_eid("foo")];
629        let bar = &entities[&EntityUID::with_eid("bar")];
630        let baz = &entities[&EntityUID::with_eid("baz")];
631        // should have added many edges which we check for
632        assert!(foo.is_descendant_of(&EntityUID::with_eid("baz")));
633        assert!(foo.is_descendant_of(&EntityUID::with_eid("ham")));
634        assert!(foo.is_descendant_of(&EntityUID::with_eid("eggs")));
635        assert!(bar.is_descendant_of(&EntityUID::with_eid("ham")));
636        assert!(bar.is_descendant_of(&EntityUID::with_eid("eggs")));
637        assert!(baz.is_descendant_of(&EntityUID::with_eid("eggs")));
638        // now it should pass TC enforcement
639        assert!(enforce_tc(&entities).is_ok());
640        // passes cycle check after TC enforcement
641        assert!(enforce_dag_from_tc(&entities).is_ok());
642    }
643
644    #[test]
645    fn multi_parents() {
646        // start with this:
647        //     B -> C
648        //   /
649        // A
650        //   \
651        //     D -> E
652        let mut a = Entity::with_uid(EntityUID::with_eid("A"));
653        a.add_parent(EntityUID::with_eid("B"));
654        a.add_parent(EntityUID::with_eid("D"));
655        let mut b = Entity::with_uid(EntityUID::with_eid("B"));
656        b.add_parent(EntityUID::with_eid("C"));
657        let c = Entity::with_uid(EntityUID::with_eid("C"));
658        let mut d = Entity::with_uid(EntityUID::with_eid("D"));
659        d.add_parent(EntityUID::with_eid("E"));
660        let e = Entity::with_uid(EntityUID::with_eid("E"));
661        let mut entities = HashMap::from([
662            (a.uid().clone(), a),
663            (b.uid().clone(), b),
664            (c.uid().clone(), c),
665            (d.uid().clone(), d),
666            (e.uid().clone(), e),
667        ]);
668        // currently doesn't pass TC enforcement
669        assert!(enforce_tc(&entities).is_err());
670        // compute TC
671        compute_tc(&mut entities, false).expect("Failed to compute transitive closure");
672        let a = &entities[&EntityUID::with_eid("A")];
673        let b = &entities[&EntityUID::with_eid("B")];
674        let d = &entities[&EntityUID::with_eid("D")];
675        // should have added the A -> C edge and the A -> E edge
676        assert!(a.is_descendant_of(&EntityUID::with_eid("C")));
677        assert!(a.is_descendant_of(&EntityUID::with_eid("E")));
678        // but it shouldn't have added these other edges
679        assert!(!b.is_descendant_of(&EntityUID::with_eid("D")));
680        assert!(!b.is_descendant_of(&EntityUID::with_eid("E")));
681        assert!(!d.is_descendant_of(&EntityUID::with_eid("B")));
682        assert!(!d.is_descendant_of(&EntityUID::with_eid("C")));
683        // now it should pass TC enforcement
684        assert!(enforce_tc(&entities).is_ok());
685        // passes cycle check after TC enforcement
686        assert!(enforce_dag_from_tc(&entities).is_ok());
687    }
688
689    #[test]
690    fn dag() {
691        // start with this:
692        //     B -> C
693        //   /  \
694        // A      D -> E -> H
695        //   \        /
696        //     F -> G
697        let mut a = Entity::with_uid(EntityUID::with_eid("A"));
698        a.add_parent(EntityUID::with_eid("B"));
699        a.add_parent(EntityUID::with_eid("F"));
700        let mut b = Entity::with_uid(EntityUID::with_eid("B"));
701        b.add_parent(EntityUID::with_eid("C"));
702        b.add_parent(EntityUID::with_eid("D"));
703        let c = Entity::with_uid(EntityUID::with_eid("C"));
704        let mut d = Entity::with_uid(EntityUID::with_eid("D"));
705        d.add_parent(EntityUID::with_eid("E"));
706        let mut e = Entity::with_uid(EntityUID::with_eid("E"));
707        e.add_parent(EntityUID::with_eid("H"));
708        let mut f = Entity::with_uid(EntityUID::with_eid("F"));
709        f.add_parent(EntityUID::with_eid("G"));
710        let mut g = Entity::with_uid(EntityUID::with_eid("G"));
711        g.add_parent(EntityUID::with_eid("E"));
712        let h = Entity::with_uid(EntityUID::with_eid("H"));
713        let mut entities = HashMap::from([
714            (a.uid().clone(), a),
715            (b.uid().clone(), b),
716            (c.uid().clone(), c),
717            (d.uid().clone(), d),
718            (e.uid().clone(), e),
719            (f.uid().clone(), f),
720            (g.uid().clone(), g),
721            (h.uid().clone(), h),
722        ]);
723        // currently doesn't pass TC enforcement
724        assert!(enforce_tc(&entities).is_err());
725        // compute TC
726        compute_tc(&mut entities, false).expect("Failed to compute transitive closure");
727        let a = &entities[&EntityUID::with_eid("A")];
728        let b = &entities[&EntityUID::with_eid("B")];
729        let f = &entities[&EntityUID::with_eid("F")];
730        // should have added many edges which we check for
731        assert!(a.is_descendant_of(&EntityUID::with_eid("C")));
732        assert!(a.is_descendant_of(&EntityUID::with_eid("D")));
733        assert!(a.is_descendant_of(&EntityUID::with_eid("E")));
734        assert!(a.is_descendant_of(&EntityUID::with_eid("F")));
735        assert!(a.is_descendant_of(&EntityUID::with_eid("G")));
736        assert!(a.is_descendant_of(&EntityUID::with_eid("H")));
737        assert!(b.is_descendant_of(&EntityUID::with_eid("E")));
738        assert!(b.is_descendant_of(&EntityUID::with_eid("H")));
739        assert!(f.is_descendant_of(&EntityUID::with_eid("E")));
740        assert!(f.is_descendant_of(&EntityUID::with_eid("H")));
741        // but it shouldn't have added these other edges
742        assert!(!b.is_descendant_of(&EntityUID::with_eid("F")));
743        assert!(!b.is_descendant_of(&EntityUID::with_eid("G")));
744        assert!(!f.is_descendant_of(&EntityUID::with_eid("C")));
745        assert!(!f.is_descendant_of(&EntityUID::with_eid("D")));
746        // now it should pass TC enforcement
747        assert!(enforce_tc(&entities).is_ok());
748        // passes cycle check after TC enforcement
749        assert!(enforce_dag_from_tc(&entities).is_ok());
750    }
751
752    #[test]
753    fn already_edges() {
754        // start with this, which already includes some (but not all) transitive
755        // edges
756        //     B --> E
757        //   /  \   /
758        // A ---> C
759        //   \   /
760        //     D --> F
761        let mut a = Entity::with_uid(EntityUID::with_eid("A"));
762        a.add_parent(EntityUID::with_eid("B"));
763        a.add_parent(EntityUID::with_eid("C"));
764        a.add_parent(EntityUID::with_eid("D"));
765        let mut b = Entity::with_uid(EntityUID::with_eid("B"));
766        b.add_parent(EntityUID::with_eid("C"));
767        b.add_parent(EntityUID::with_eid("E"));
768        let mut c = Entity::with_uid(EntityUID::with_eid("C"));
769        c.add_parent(EntityUID::with_eid("E"));
770        let mut d = Entity::with_uid(EntityUID::with_eid("D"));
771        d.add_parent(EntityUID::with_eid("C"));
772        d.add_parent(EntityUID::with_eid("F"));
773        let e = Entity::with_uid(EntityUID::with_eid("E"));
774        let f = Entity::with_uid(EntityUID::with_eid("F"));
775        let mut entities = HashMap::from([
776            (a.uid().clone(), a),
777            (b.uid().clone(), b),
778            (c.uid().clone(), c),
779            (d.uid().clone(), d),
780            (e.uid().clone(), e),
781            (f.uid().clone(), f),
782        ]);
783        // currently doesn't pass TC enforcement
784        assert!(enforce_tc(&entities).is_err());
785        // compute TC
786        compute_tc(&mut entities, false).expect("Failed to compute transitive closure");
787        let a = &entities[&EntityUID::with_eid("A")];
788        let b = &entities[&EntityUID::with_eid("B")];
789        let c = &entities[&EntityUID::with_eid("C")];
790        let d = &entities[&EntityUID::with_eid("D")];
791        // should have added many edges which we check for
792        assert!(a.is_descendant_of(&EntityUID::with_eid("E")));
793        assert!(a.is_descendant_of(&EntityUID::with_eid("F")));
794        assert!(d.is_descendant_of(&EntityUID::with_eid("E")));
795        // but it shouldn't have added these other edges
796        assert!(!b.is_descendant_of(&EntityUID::with_eid("F")));
797        assert!(!c.is_descendant_of(&EntityUID::with_eid("F")));
798        // now it should pass TC enforcement
799        assert!(enforce_tc(&entities).is_ok());
800        // passes cycle check after TC enforcement
801        assert!(enforce_dag_from_tc(&entities).is_ok());
802    }
803
804    #[test]
805    fn disjoint_dag() {
806        // graph with disconnected components:
807        //     B -> C
808        //
809        // A      D -> E -> H
810        //   \
811        //     F -> G
812        let mut a = Entity::with_uid(EntityUID::with_eid("A"));
813        a.add_parent(EntityUID::with_eid("F"));
814        let mut b = Entity::with_uid(EntityUID::with_eid("B"));
815        b.add_parent(EntityUID::with_eid("C"));
816        let c = Entity::with_uid(EntityUID::with_eid("C"));
817        let mut d = Entity::with_uid(EntityUID::with_eid("D"));
818        d.add_parent(EntityUID::with_eid("E"));
819        let mut e = Entity::with_uid(EntityUID::with_eid("E"));
820        e.add_parent(EntityUID::with_eid("H"));
821        let mut f = Entity::with_uid(EntityUID::with_eid("F"));
822        f.add_parent(EntityUID::with_eid("G"));
823        let g = Entity::with_uid(EntityUID::with_eid("G"));
824        let h = Entity::with_uid(EntityUID::with_eid("H"));
825        let mut entities = HashMap::from([
826            (a.uid().clone(), a),
827            (b.uid().clone(), b),
828            (c.uid().clone(), c),
829            (d.uid().clone(), d),
830            (e.uid().clone(), e),
831            (f.uid().clone(), f),
832            (g.uid().clone(), g),
833            (h.uid().clone(), h),
834        ]);
835        // currently doesn't pass TC enforcement
836        assert!(enforce_tc(&entities).is_err());
837        // compute TC
838        compute_tc(&mut entities, false).expect("Failed to compute transitive closure");
839        let a = &entities[&EntityUID::with_eid("A")];
840        let b = &entities[&EntityUID::with_eid("B")];
841        let d = &entities[&EntityUID::with_eid("D")];
842        let f = &entities[&EntityUID::with_eid("F")];
843        // should have added two edges
844        assert!(a.is_descendant_of(&EntityUID::with_eid("G")));
845        assert!(d.is_descendant_of(&EntityUID::with_eid("H")));
846        // but it shouldn't have added these other edges
847        assert!(!a.is_descendant_of(&EntityUID::with_eid("C")));
848        assert!(!a.is_descendant_of(&EntityUID::with_eid("D")));
849        assert!(!a.is_descendant_of(&EntityUID::with_eid("E")));
850        assert!(!a.is_descendant_of(&EntityUID::with_eid("H")));
851        assert!(!b.is_descendant_of(&EntityUID::with_eid("E")));
852        assert!(!b.is_descendant_of(&EntityUID::with_eid("H")));
853        assert!(!f.is_descendant_of(&EntityUID::with_eid("E")));
854        assert!(!f.is_descendant_of(&EntityUID::with_eid("H")));
855        assert!(!b.is_descendant_of(&EntityUID::with_eid("F")));
856        assert!(!b.is_descendant_of(&EntityUID::with_eid("G")));
857        assert!(!f.is_descendant_of(&EntityUID::with_eid("C")));
858        assert!(!f.is_descendant_of(&EntityUID::with_eid("D")));
859        // now it should pass TC enforcement
860        assert!(enforce_tc(&entities).is_ok());
861        // passes cycle check after TC enforcement
862        assert!(enforce_dag_from_tc(&entities).is_ok());
863    }
864
865    #[test]
866    fn trivial_cycle() {
867        // this graph is invalid, but we want to still have some reasonable behavior
868        // if we encounter it (and definitely not panic, infinitely recurse, etc)
869        //
870        // A -> B -> B
871        let mut a = Entity::with_uid(EntityUID::with_eid("A"));
872        a.add_parent(EntityUID::with_eid("B"));
873        let mut b = Entity::with_uid(EntityUID::with_eid("B"));
874        b.add_parent(EntityUID::with_eid("B"));
875        let mut entities = HashMap::from([(a.uid().clone(), a), (b.uid().clone(), b)]);
876        // computing TC should succeed without panicking, infinitely recursing, etc
877        compute_tc(&mut entities, false).expect("Failed to compute transitive closure");
878        // fails cycle check
879        match enforce_dag_from_tc(&entities) {
880            Ok(_) => panic!("enforce_dag_from_tc should have returned an error"),
881            Err(TcError::HasCycle(err)) => {
882                assert!(err.vertex_with_loop() == &EntityUID::with_eid("B"));
883            }
884            Err(_) => panic!("Unexpected error in enforce_dag_from_tc"),
885        }
886        let a = &entities[&EntityUID::with_eid("A")];
887        let b = &entities[&EntityUID::with_eid("B")];
888        // we check that the A -> B edge still exists
889        assert!(a.is_descendant_of(&EntityUID::with_eid("B")));
890        // but it shouldn't have added a B -> A edge
891        assert!(!b.is_descendant_of(&EntityUID::with_eid("A")));
892        // we also check that, whatever compute_tc did with this invalid input, the
893        // final result still passes enforce_tc
894        assert!(enforce_tc(&entities).is_ok());
895        // still fails cycle check
896        match enforce_dag_from_tc(&entities) {
897            Ok(_) => panic!("enforce_dag_from_tc should have returned an error"),
898            Err(TcError::HasCycle(err)) => {
899                assert!(err.vertex_with_loop() == &EntityUID::with_eid("B"));
900            }
901            Err(_) => panic!("Unexpected error in enforce_dag_from_tc"),
902        }
903    }
904
905    #[test]
906    fn nontrivial_cycle() {
907        // this graph is invalid, but we want to still have some reasonable behavior
908        // if we encounter it (and definitely not panic, infinitely recurse, etc)
909        //
910        //          D
911        //        /
912        // A -> B -> C -> A
913        let mut a = Entity::with_uid(EntityUID::with_eid("A"));
914        a.add_parent(EntityUID::with_eid("B"));
915        let mut b = Entity::with_uid(EntityUID::with_eid("B"));
916        b.add_parent(EntityUID::with_eid("C"));
917        b.add_parent(EntityUID::with_eid("D"));
918        let mut c = Entity::with_uid(EntityUID::with_eid("C"));
919        c.add_parent(EntityUID::with_eid("A"));
920        let d = Entity::with_uid(EntityUID::with_eid("D"));
921        let mut entities = HashMap::from([
922            (a.uid().clone(), a),
923            (b.uid().clone(), b),
924            (c.uid().clone(), c),
925            (d.uid().clone(), d),
926        ]);
927        // computing TC should succeed without panicking, infinitely recursing, etc
928        compute_tc_internal(
929            entities
930                .keys()
931                .map(EntityUID::clone)
932                .collect::<Vec<EntityUID>>()
933                .into_iter(),
934            &mut entities,
935            HashSet::new(),
936            true,
937        );
938        // fails cycle check
939        match enforce_dag_from_tc(&entities) {
940            Ok(_) => panic!("enforce_dag_from_tc should have returned an error"),
941            Err(TcError::HasCycle(err)) => {
942                assert!(
943                    err.vertex_with_loop() == &EntityUID::with_eid("A")
944                        || err.vertex_with_loop() == &EntityUID::with_eid("B")
945                        || err.vertex_with_loop() == &EntityUID::with_eid("C")
946                );
947            }
948            Err(_) => panic!("Unexpected error in enforce_dag_from_tc"),
949        }
950        //TC tests
951        let a = &entities[&EntityUID::with_eid("A")];
952        let b = &entities[&EntityUID::with_eid("B")];
953        // we should have added A -> C and A -> D edges, at least
954        assert!(a.is_descendant_of(&EntityUID::with_eid("C")));
955        assert!(a.is_descendant_of(&EntityUID::with_eid("D")));
956        // and we should also have added a B -> A edge
957        assert!(b.is_descendant_of(&EntityUID::with_eid("A")));
958        // we also check that, whatever compute_tc did with this invalid input, the
959        // final result still passes enforce_tc
960        assert!(enforce_tc(&entities).is_ok());
961        // still fails cycle check
962        match enforce_dag_from_tc(&entities) {
963            Ok(_) => panic!("enforce_dag_from_tc should have returned an error"),
964            Err(TcError::HasCycle(err)) => {
965                assert!(
966                    err.vertex_with_loop() == &EntityUID::with_eid("A")
967                        || err.vertex_with_loop() == &EntityUID::with_eid("B")
968                        || err.vertex_with_loop() == &EntityUID::with_eid("C")
969                );
970            }
971            Err(_) => panic!("Unexpected error in enforce_dag_from_tc"),
972        }
973    }
974
975    #[test]
976    fn disjoint_cycles() {
977        // graph with disconnected components including cycles:
978        //     B -> C -> B
979        //
980        // A      D -> E -> H -> D
981        //   \
982        //     F -> G
983        let mut a = Entity::with_uid(EntityUID::with_eid("A"));
984        a.add_parent(EntityUID::with_eid("F"));
985        let mut b = Entity::with_uid(EntityUID::with_eid("B"));
986        b.add_parent(EntityUID::with_eid("C"));
987        let mut c: Entity = Entity::with_uid(EntityUID::with_eid("C"));
988        c.add_parent(EntityUID::with_eid("B"));
989        let mut d = Entity::with_uid(EntityUID::with_eid("D"));
990        d.add_parent(EntityUID::with_eid("E"));
991        let mut e = Entity::with_uid(EntityUID::with_eid("E"));
992        e.add_parent(EntityUID::with_eid("H"));
993        let mut f = Entity::with_uid(EntityUID::with_eid("F"));
994        f.add_parent(EntityUID::with_eid("G"));
995        let g = Entity::with_uid(EntityUID::with_eid("G"));
996        let mut h = Entity::with_uid(EntityUID::with_eid("H"));
997        h.add_parent(EntityUID::with_eid("D"));
998        let mut entities = HashMap::from([
999            (a.uid().clone(), a),
1000            (b.uid().clone(), b),
1001            (c.uid().clone(), c),
1002            (d.uid().clone(), d),
1003            (e.uid().clone(), e),
1004            (f.uid().clone(), f),
1005            (g.uid().clone(), g),
1006            (h.uid().clone(), h),
1007        ]);
1008        // currently doesn't pass TC enforcement
1009        assert!(enforce_tc(&entities).is_err());
1010        // compute TC
1011        compute_tc_internal(
1012            entities
1013                .keys()
1014                .map(EntityUID::clone)
1015                .collect::<Vec<EntityUID>>()
1016                .into_iter(),
1017            &mut entities,
1018            HashSet::new(),
1019            true,
1020        );
1021        // now it should pass TC enforcement
1022        assert!(enforce_tc(&entities).is_ok());
1023        // still fails cycle check
1024        match enforce_dag_from_tc(&entities) {
1025            Ok(_) => panic!("enforce_dag_from_tc should have returned an error"),
1026            Err(TcError::HasCycle(err)) => {
1027                // two possible cycles
1028                assert!(
1029                    err.vertex_with_loop() == &EntityUID::with_eid("B")
1030                        || err.vertex_with_loop() == &EntityUID::with_eid("C")
1031                        || err.vertex_with_loop() == &EntityUID::with_eid("D")
1032                        || err.vertex_with_loop() == &EntityUID::with_eid("E")
1033                        || err.vertex_with_loop() == &EntityUID::with_eid("H")
1034                );
1035            }
1036            Err(_) => panic!("Unexpected error in enforce_dag_from_tc"),
1037        }
1038    }
1039
1040    #[test]
1041    fn intersecting_cycles() {
1042        // graph with two intersecting cycles:
1043        //  A -> B -> C -> E -
1044        //  ^    ^         ^  |
1045        //  |    |         |  |
1046        //  |    |        /   |
1047        //   \  /        /    |
1048        //    D ------>F      |
1049        //    ^               |
1050        //    |___------------
1051        let mut a = Entity::with_uid(EntityUID::with_eid("A"));
1052        a.add_parent(EntityUID::with_eid("B"));
1053        let mut b = Entity::with_uid(EntityUID::with_eid("B"));
1054        b.add_parent(EntityUID::with_eid("C"));
1055        let mut c = Entity::with_uid(EntityUID::with_eid("C"));
1056        c.add_parent(EntityUID::with_eid("E"));
1057        let mut d = Entity::with_uid(EntityUID::with_eid("D"));
1058        d.add_parent(EntityUID::with_eid("A"));
1059        d.add_parent(EntityUID::with_eid("B"));
1060        d.add_parent(EntityUID::with_eid("F"));
1061        let mut e = Entity::with_uid(EntityUID::with_eid("E"));
1062        e.add_parent(EntityUID::with_eid("D"));
1063        let mut f = Entity::with_uid(EntityUID::with_eid("F"));
1064        f.add_parent(EntityUID::with_eid("E"));
1065        let mut entities = HashMap::from([
1066            (a.uid().clone(), a),
1067            (b.uid().clone(), b),
1068            (c.uid().clone(), c),
1069            (d.uid().clone(), d),
1070            (e.uid().clone(), e),
1071            (f.uid().clone(), f),
1072        ]);
1073        // fails TC enforcement
1074        assert!(enforce_tc(&entities).is_err());
1075        // compute TC
1076        cyclic_tc(&mut entities);
1077        assert!(enforce_tc(&entities).is_ok());
1078        // the graph may or may not pass the TC check but it will always fail cycle check
1079        match enforce_dag_from_tc(&entities) {
1080            Ok(_) => panic!("enforce_dag_from_tc should have returned an error"),
1081            Err(TcError::HasCycle(_)) => (), // Every vertex is in a cycle
1082            Err(_) => panic!("Unexpected error in enforce_dag_from_tc"),
1083        }
1084    }
1085
1086    #[test]
1087    fn updated() {
1088        // start with A -> B -> C
1089        let mut a = Entity::with_uid(EntityUID::with_eid("A"));
1090        a.add_parent(EntityUID::with_eid("B"));
1091        let mut b = Entity::with_uid(EntityUID::with_eid("B"));
1092        b.add_parent(EntityUID::with_eid("C"));
1093        let c = Entity::with_uid(EntityUID::with_eid("C"));
1094        let mut entities = HashMap::from([
1095            (a.uid().clone(), a),
1096            (b.uid().clone(), b),
1097            (c.uid().clone(), c),
1098        ]);
1099        // currently doesn't pass TC enforcement
1100        assert!(enforce_tc(&entities).is_err());
1101        // compute TC
1102        let all_ids: Vec<_> = entities.keys().cloned().collect();
1103        compute_tc_internal(all_ids.into_iter(), &mut entities, HashSet::new(), false);
1104        let a = &entities[&EntityUID::with_eid("A")];
1105        let b = &entities[&EntityUID::with_eid("B")];
1106        let c = &entities[&EntityUID::with_eid("C")];
1107        // should have added the A -> C edge
1108        assert!(a.has_edge_to(&EntityUID::with_eid("C")));
1109        // but we shouldn't have added other edges, like B -> A or C -> A
1110        assert!(!b.has_edge_to(&EntityUID::with_eid("A")));
1111        assert!(!c.has_edge_to(&EntityUID::with_eid("A")));
1112        // now it should pass TC enforcement
1113        assert!(enforce_tc(&entities).is_ok());
1114        // passes cycle check after TC enforcement
1115        assert!(enforce_dag_from_tc(&entities).is_ok());
1116        // D doesn't exist yet
1117        assert!(!a.has_edge_to(&EntityUID::with_eid("D")));
1118
1119        // change from B -> C to B -> D
1120        // Recreate A with only its original parent (B) to clear stale TC edges
1121        let mut a = Entity::with_uid(EntityUID::with_eid("A"));
1122        a.add_parent(EntityUID::with_eid("B"));
1123        let mut b = Entity::with_uid(EntityUID::with_eid("B"));
1124        b.add_parent(EntityUID::with_eid("D"));
1125        let d = Entity::with_uid(EntityUID::with_eid("D"));
1126        entities.insert(a.uid().clone(), a);
1127        entities.insert(b.uid().clone(), b);
1128        entities.insert(d.uid().clone(), d);
1129
1130        // currently doesn't pass TC enforcement
1131        assert!(enforce_tc(&entities).is_err());
1132        // compute TC
1133        let all_ids: Vec<_> = entities.keys().cloned().collect();
1134        compute_tc_internal(all_ids.into_iter(), &mut entities, HashSet::new(), false);
1135        let a = &entities[&EntityUID::with_eid("A")];
1136        let b = &entities[&EntityUID::with_eid("B")];
1137        let c = &entities[&EntityUID::with_eid("C")];
1138        let d = &entities[&EntityUID::with_eid("D")];
1139        // should have added the A -> D edge
1140        assert!(a.has_edge_to(&EntityUID::with_eid("D")));
1141        // should not have the A -> C edge
1142        assert!(!a.has_edge_to(&EntityUID::with_eid("C")));
1143        assert!(!b.has_edge_to(&EntityUID::with_eid("C")));
1144        // but we shouldn't have added other edges, like B -> A or C -> A
1145        assert!(!b.has_edge_to(&EntityUID::with_eid("A")));
1146        assert!(!c.has_edge_to(&EntityUID::with_eid("A")));
1147        assert!(!d.has_edge_to(&EntityUID::with_eid("A")));
1148        // now it should pass TC enforcement
1149        assert!(enforce_tc(&entities).is_ok());
1150        // passes cycle check after TC enforcement
1151        assert!(enforce_dag_from_tc(&entities).is_ok());
1152    }
1153
1154    /// Helper: collect the out-edge set for every node
1155    fn snapshot(entities: &HashMap<EntityUID, Entity>) -> HashMap<EntityUID, HashSet<EntityUID>> {
1156        entities
1157            .iter()
1158            .map(|(k, v)| (k.clone(), v.out_edges().cloned().collect()))
1159            .collect()
1160    }
1161
1162    /// Build a fresh copy of the entity map with only direct parent edges
1163    fn fresh_copy(entities: &HashMap<EntityUID, Entity>) -> HashMap<EntityUID, Entity> {
1164        entities
1165            .values()
1166            .map(|e| {
1167                let mut fresh = Entity::with_uid(e.uid().clone());
1168                for p in e.parents() {
1169                    fresh.add_parent(p.clone());
1170                }
1171                (fresh.uid().clone(), fresh)
1172            })
1173            .collect()
1174    }
1175
1176    /// Use `cyclic_tc` as a test oracle: verify that `cyclic_tc` produces
1177    /// identical results to `compute_tc` on DAGs.
1178    #[test]
1179    fn cyclic_tc_oracle_on_dags() {
1180        for (name, base) in &dag_test_graphs() {
1181            let mut via_compute = fresh_copy(base);
1182            compute_tc(&mut via_compute, false).expect("compute_tc failed");
1183
1184            let mut via_cyclic = fresh_copy(base);
1185            cyclic_tc(&mut via_cyclic);
1186
1187            assert_eq!(
1188                snapshot(&via_compute),
1189                snapshot(&via_cyclic),
1190                "TC mismatch on DAG '{name}'"
1191            );
1192        }
1193    }
1194
1195    /// Use `cyclic_tc` as a test oracle on cyclic graphs: verify exact TC
1196    /// and that enforce_tc passes afterwards.
1197    #[test]
1198    fn cyclic_tc_oracle_on_cycles() {
1199        for (name, base) in &cyclic_test_graphs() {
1200            let mut entities = fresh_copy(base);
1201            cyclic_tc(&mut entities);
1202
1203            assert!(
1204                enforce_tc(&entities).is_ok(),
1205                "enforce_tc failed after cyclic_tc on '{name}'"
1206            );
1207            // All cyclic test graphs contain cycles, so DAG check must fail
1208            assert!(
1209                enforce_dag_from_tc(&entities).is_err(),
1210                "expected cycle detection on '{name}'"
1211            );
1212        }
1213    }
1214
1215    /// Verify that `compute_tc(..., false)` produces the same result as
1216    /// `cyclic_tc` on cyclic graphs (regression: the old single-pass algorithm
1217    /// produced incomplete closures for cycles of length >= 3).
1218    #[test]
1219    fn compute_tc_no_dag_matches_cyclic_tc_on_cycles() {
1220        for (name, base) in &cyclic_test_graphs() {
1221            let mut via_compute = fresh_copy(base);
1222            compute_tc(&mut via_compute, false)
1223                .unwrap_or_else(|_| panic!("compute_tc failed on '{name}'"));
1224
1225            let mut via_cyclic = fresh_copy(base);
1226            cyclic_tc(&mut via_cyclic);
1227
1228            let snap_compute = snapshot(&via_compute);
1229            let snap_cyclic = snapshot(&via_cyclic);
1230            assert_eq!(
1231                snap_compute, snap_cyclic,
1232                "compute_tc(false) differs from cyclic_tc on '{name}'"
1233            );
1234        }
1235    }
1236
1237    #[test]
1238    fn self_loop_with_grandchild() {
1239        let mut a = Entity::with_uid(EntityUID::with_eid("A"));
1240        a.add_parent(EntityUID::with_eid("A"));
1241        a.add_parent(EntityUID::with_eid("B"));
1242        let mut b = Entity::with_uid(EntityUID::with_eid("B"));
1243        b.add_parent(EntityUID::with_eid("C"));
1244        let c = Entity::with_uid(EntityUID::with_eid("C"));
1245        let mut entities = HashMap::from([
1246            (a.uid().clone(), a),
1247            (b.uid().clone(), b),
1248            (c.uid().clone(), c),
1249        ]);
1250        compute_tc(&mut entities, false).expect("compute_tc failed");
1251        assert!(entities[&EntityUID::with_eid("A")].has_edge_to(&EntityUID::with_eid("C")));
1252        assert!(enforce_tc(&entities).is_ok());
1253    }
1254
1255    fn dag_test_graphs() -> Vec<(&'static str, HashMap<EntityUID, Entity>)> {
1256        vec![
1257            ("A->B->C", {
1258                let mut a = Entity::with_uid(EntityUID::with_eid("A"));
1259                a.add_parent(EntityUID::with_eid("B"));
1260                let mut b = Entity::with_uid(EntityUID::with_eid("B"));
1261                b.add_parent(EntityUID::with_eid("C"));
1262                let c = Entity::with_uid(EntityUID::with_eid("C"));
1263                HashMap::from([
1264                    (a.uid().clone(), a),
1265                    (b.uid().clone(), b),
1266                    (c.uid().clone(), c),
1267                ])
1268            }),
1269            ("A->B->C->D->E", {
1270                let mut a = Entity::with_uid(EntityUID::with_eid("A"));
1271                a.add_parent(EntityUID::with_eid("B"));
1272                let mut b = Entity::with_uid(EntityUID::with_eid("B"));
1273                b.add_parent(EntityUID::with_eid("C"));
1274                let mut c = Entity::with_uid(EntityUID::with_eid("C"));
1275                c.add_parent(EntityUID::with_eid("D"));
1276                let mut d = Entity::with_uid(EntityUID::with_eid("D"));
1277                d.add_parent(EntityUID::with_eid("E"));
1278                let e = Entity::with_uid(EntityUID::with_eid("E"));
1279                HashMap::from([
1280                    (a.uid().clone(), a),
1281                    (b.uid().clone(), b),
1282                    (c.uid().clone(), c),
1283                    (d.uid().clone(), d),
1284                    (e.uid().clone(), e),
1285                ])
1286            }),
1287            ("multi_parents", {
1288                //     B -> C
1289                //   /
1290                // A
1291                //   \
1292                //     D -> E
1293                let mut a = Entity::with_uid(EntityUID::with_eid("A"));
1294                a.add_parent(EntityUID::with_eid("B"));
1295                a.add_parent(EntityUID::with_eid("D"));
1296                let mut b = Entity::with_uid(EntityUID::with_eid("B"));
1297                b.add_parent(EntityUID::with_eid("C"));
1298                let c = Entity::with_uid(EntityUID::with_eid("C"));
1299                let mut d = Entity::with_uid(EntityUID::with_eid("D"));
1300                d.add_parent(EntityUID::with_eid("E"));
1301                let e = Entity::with_uid(EntityUID::with_eid("E"));
1302                HashMap::from([
1303                    (a.uid().clone(), a),
1304                    (b.uid().clone(), b),
1305                    (c.uid().clone(), c),
1306                    (d.uid().clone(), d),
1307                    (e.uid().clone(), e),
1308                ])
1309            }),
1310            ("diamond", {
1311                // A -> B -> D
1312                // A -> C -> D
1313                let mut a = Entity::with_uid(EntityUID::with_eid("A"));
1314                a.add_parent(EntityUID::with_eid("B"));
1315                a.add_parent(EntityUID::with_eid("C"));
1316                let mut b = Entity::with_uid(EntityUID::with_eid("B"));
1317                b.add_parent(EntityUID::with_eid("D"));
1318                let mut c = Entity::with_uid(EntityUID::with_eid("C"));
1319                c.add_parent(EntityUID::with_eid("D"));
1320                let d = Entity::with_uid(EntityUID::with_eid("D"));
1321                HashMap::from([
1322                    (a.uid().clone(), a),
1323                    (b.uid().clone(), b),
1324                    (c.uid().clone(), c),
1325                    (d.uid().clone(), d),
1326                ])
1327            }),
1328            ("dag_with_join", {
1329                //     B -> C
1330                //   /  \
1331                // A      D -> E -> H
1332                //   \        /
1333                //     F -> G
1334                let mut a = Entity::with_uid(EntityUID::with_eid("A"));
1335                a.add_parent(EntityUID::with_eid("B"));
1336                a.add_parent(EntityUID::with_eid("F"));
1337                let mut b = Entity::with_uid(EntityUID::with_eid("B"));
1338                b.add_parent(EntityUID::with_eid("C"));
1339                b.add_parent(EntityUID::with_eid("D"));
1340                let c = Entity::with_uid(EntityUID::with_eid("C"));
1341                let mut d = Entity::with_uid(EntityUID::with_eid("D"));
1342                d.add_parent(EntityUID::with_eid("E"));
1343                let mut e = Entity::with_uid(EntityUID::with_eid("E"));
1344                e.add_parent(EntityUID::with_eid("H"));
1345                let mut f = Entity::with_uid(EntityUID::with_eid("F"));
1346                f.add_parent(EntityUID::with_eid("G"));
1347                let mut g = Entity::with_uid(EntityUID::with_eid("G"));
1348                g.add_parent(EntityUID::with_eid("E"));
1349                let h = Entity::with_uid(EntityUID::with_eid("H"));
1350                HashMap::from([
1351                    (a.uid().clone(), a),
1352                    (b.uid().clone(), b),
1353                    (c.uid().clone(), c),
1354                    (d.uid().clone(), d),
1355                    (e.uid().clone(), e),
1356                    (f.uid().clone(), f),
1357                    (g.uid().clone(), g),
1358                    (h.uid().clone(), h),
1359                ])
1360            }),
1361            ("already_edges", {
1362                //     B --> E
1363                //   /  \   /
1364                // A ---> C
1365                //   \   /
1366                //     D --> F
1367                let mut a = Entity::with_uid(EntityUID::with_eid("A"));
1368                a.add_parent(EntityUID::with_eid("B"));
1369                a.add_parent(EntityUID::with_eid("C"));
1370                a.add_parent(EntityUID::with_eid("D"));
1371                let mut b = Entity::with_uid(EntityUID::with_eid("B"));
1372                b.add_parent(EntityUID::with_eid("C"));
1373                b.add_parent(EntityUID::with_eid("E"));
1374                let mut c = Entity::with_uid(EntityUID::with_eid("C"));
1375                c.add_parent(EntityUID::with_eid("E"));
1376                let mut d = Entity::with_uid(EntityUID::with_eid("D"));
1377                d.add_parent(EntityUID::with_eid("C"));
1378                d.add_parent(EntityUID::with_eid("F"));
1379                let e = Entity::with_uid(EntityUID::with_eid("E"));
1380                let f = Entity::with_uid(EntityUID::with_eid("F"));
1381                HashMap::from([
1382                    (a.uid().clone(), a),
1383                    (b.uid().clone(), b),
1384                    (c.uid().clone(), c),
1385                    (d.uid().clone(), d),
1386                    (e.uid().clone(), e),
1387                    (f.uid().clone(), f),
1388                ])
1389            }),
1390            ("disjoint", {
1391                //     B -> C
1392                //
1393                // A      D -> E -> H
1394                //   \
1395                //     F -> G
1396                let mut a = Entity::with_uid(EntityUID::with_eid("A"));
1397                a.add_parent(EntityUID::with_eid("F"));
1398                let mut b = Entity::with_uid(EntityUID::with_eid("B"));
1399                b.add_parent(EntityUID::with_eid("C"));
1400                let c = Entity::with_uid(EntityUID::with_eid("C"));
1401                let mut d = Entity::with_uid(EntityUID::with_eid("D"));
1402                d.add_parent(EntityUID::with_eid("E"));
1403                let mut e = Entity::with_uid(EntityUID::with_eid("E"));
1404                e.add_parent(EntityUID::with_eid("H"));
1405                let mut f = Entity::with_uid(EntityUID::with_eid("F"));
1406                f.add_parent(EntityUID::with_eid("G"));
1407                let g = Entity::with_uid(EntityUID::with_eid("G"));
1408                let h = Entity::with_uid(EntityUID::with_eid("H"));
1409                HashMap::from([
1410                    (a.uid().clone(), a),
1411                    (b.uid().clone(), b),
1412                    (c.uid().clone(), c),
1413                    (d.uid().clone(), d),
1414                    (e.uid().clone(), e),
1415                    (f.uid().clone(), f),
1416                    (g.uid().clone(), g),
1417                    (h.uid().clone(), h),
1418                ])
1419            }),
1420        ]
1421    }
1422
1423    fn cyclic_test_graphs() -> Vec<(&'static str, HashMap<EntityUID, Entity>)> {
1424        vec![
1425            ("trivial_cycle", {
1426                // A -> B -> B
1427                let mut a = Entity::with_uid(EntityUID::with_eid("A"));
1428                a.add_parent(EntityUID::with_eid("B"));
1429                let mut b = Entity::with_uid(EntityUID::with_eid("B"));
1430                b.add_parent(EntityUID::with_eid("B"));
1431                HashMap::from([(a.uid().clone(), a), (b.uid().clone(), b)])
1432            }),
1433            ("simple_cycle", {
1434                // A -> B -> C -> A, B -> D
1435                let mut a = Entity::with_uid(EntityUID::with_eid("A"));
1436                a.add_parent(EntityUID::with_eid("B"));
1437                let mut b = Entity::with_uid(EntityUID::with_eid("B"));
1438                b.add_parent(EntityUID::with_eid("C"));
1439                b.add_parent(EntityUID::with_eid("D"));
1440                let mut c = Entity::with_uid(EntityUID::with_eid("C"));
1441                c.add_parent(EntityUID::with_eid("A"));
1442                let d = Entity::with_uid(EntityUID::with_eid("D"));
1443                HashMap::from([
1444                    (a.uid().clone(), a),
1445                    (b.uid().clone(), b),
1446                    (c.uid().clone(), c),
1447                    (d.uid().clone(), d),
1448                ])
1449            }),
1450            ("disjoint_cycles", {
1451                // B -> C -> B,  D -> E -> H -> D
1452                // A -> F -> G
1453                let mut a = Entity::with_uid(EntityUID::with_eid("A"));
1454                a.add_parent(EntityUID::with_eid("F"));
1455                let mut b = Entity::with_uid(EntityUID::with_eid("B"));
1456                b.add_parent(EntityUID::with_eid("C"));
1457                let mut c = Entity::with_uid(EntityUID::with_eid("C"));
1458                c.add_parent(EntityUID::with_eid("B"));
1459                let mut d = Entity::with_uid(EntityUID::with_eid("D"));
1460                d.add_parent(EntityUID::with_eid("E"));
1461                let mut e = Entity::with_uid(EntityUID::with_eid("E"));
1462                e.add_parent(EntityUID::with_eid("H"));
1463                let mut f = Entity::with_uid(EntityUID::with_eid("F"));
1464                f.add_parent(EntityUID::with_eid("G"));
1465                let g = Entity::with_uid(EntityUID::with_eid("G"));
1466                let mut h = Entity::with_uid(EntityUID::with_eid("H"));
1467                h.add_parent(EntityUID::with_eid("D"));
1468                HashMap::from([
1469                    (a.uid().clone(), a),
1470                    (b.uid().clone(), b),
1471                    (c.uid().clone(), c),
1472                    (d.uid().clone(), d),
1473                    (e.uid().clone(), e),
1474                    (f.uid().clone(), f),
1475                    (g.uid().clone(), g),
1476                    (h.uid().clone(), h),
1477                ])
1478            }),
1479            ("intersecting_cycles", {
1480                //  A -> B -> C -> E -> D -> A
1481                //  D -> B, D -> F -> E
1482                let mut a = Entity::with_uid(EntityUID::with_eid("A"));
1483                a.add_parent(EntityUID::with_eid("B"));
1484                let mut b = Entity::with_uid(EntityUID::with_eid("B"));
1485                b.add_parent(EntityUID::with_eid("C"));
1486                let mut c = Entity::with_uid(EntityUID::with_eid("C"));
1487                c.add_parent(EntityUID::with_eid("E"));
1488                let mut d = Entity::with_uid(EntityUID::with_eid("D"));
1489                d.add_parent(EntityUID::with_eid("A"));
1490                d.add_parent(EntityUID::with_eid("B"));
1491                d.add_parent(EntityUID::with_eid("F"));
1492                let mut e = Entity::with_uid(EntityUID::with_eid("E"));
1493                e.add_parent(EntityUID::with_eid("D"));
1494                let mut f = Entity::with_uid(EntityUID::with_eid("F"));
1495                f.add_parent(EntityUID::with_eid("E"));
1496                HashMap::from([
1497                    (a.uid().clone(), a),
1498                    (b.uid().clone(), b),
1499                    (c.uid().clone(), c),
1500                    (d.uid().clone(), d),
1501                    (e.uid().clone(), e),
1502                    (f.uid().clone(), f),
1503                ])
1504            }),
1505        ]
1506    }
1507
1508    #[test]
1509    fn add_ancestors_dangling_parent_adds_transitive() {
1510        // Setup A -> B -> C and A -> X
1511        let mut a = Entity::with_uid(EntityUID::with_eid("A"));
1512        a.add_parent(EntityUID::with_eid("B"));
1513        a.add_parent(EntityUID::with_eid("X"));
1514        let mut b = Entity::with_uid(EntityUID::with_eid("B"));
1515        b.add_parent(EntityUID::with_eid("C"));
1516        let c = Entity::with_uid(EntityUID::with_eid("C"));
1517        let mut entities = HashMap::from([
1518            (a.uid().clone(), a),
1519            (b.uid().clone(), b),
1520            (c.uid().clone(), c),
1521        ]);
1522        compute_tc(&mut entities, false).expect("compute_tc failed");
1523
1524        // Entity `X` is not in the map. This previously made `add_ancestors`
1525        // return early, without adding transitive ancestors, so we lost the
1526        // transitive edge A -> C
1527        let a = entities.get(&EntityUID::with_eid("A")).unwrap();
1528        assert!(a.is_descendant_of(&EntityUID::with_eid("C")),);
1529    }
1530
1531    #[test]
1532    fn repair_tc_no_dag_computes_tc() {
1533        // Setup A -> B and B -> C
1534        let mut a = Entity::with_uid(EntityUID::with_eid("A"));
1535        a.add_parent(EntityUID::with_eid("B"));
1536        let mut b = Entity::with_uid(EntityUID::with_eid("B"));
1537        b.add_parent(EntityUID::with_eid("C"));
1538        let c = Entity::with_uid(EntityUID::with_eid("C"));
1539        let mut entities = HashMap::from([
1540            (a.uid().clone(), a),
1541            (b.uid().clone(), b),
1542            (c.uid().clone(), c),
1543        ]);
1544        compute_tc(&mut entities, false).expect("initial compute_tc failed");
1545
1546        // Add D -> B and repair TC
1547        let mut d = Entity::with_uid(EntityUID::with_eid("D"));
1548        d.add_parent(EntityUID::with_eid("B"));
1549        entities.insert(d.uid().clone(), d);
1550        let nodes_to_fix = HashSet::from([EntityUID::with_eid("D")]);
1551        repair_tc(nodes_to_fix, &mut entities, false).expect("repair_tc failed");
1552
1553        // D should get an edge to C
1554        let d = entities.get(&EntityUID::with_eid("D")).unwrap();
1555        assert!(d.is_descendant_of(&EntityUID::with_eid("C")));
1556    }
1557
1558    /// Each test case: (name, initial edges, new edges, nodes to fix).
1559    /// The test builds the initial graph, computes TC, then resets the
1560    /// nodes_to_fix, re-adds all their direct edges (initial + new), calls
1561    /// `repair_tc`, and asserts the result matches a full `compute_tc` on
1562    /// the combined edge set.
1563    fn repair_tc_test_cases() -> Vec<(
1564        &'static str,
1565        Vec<(&'static str, &'static str)>,
1566        Vec<(&'static str, &'static str)>,
1567        Vec<&'static str>,
1568    )> {
1569        vec![
1570            (
1571                "add_leaf_node",
1572                // Initial: A -> B -> C
1573                vec![("A", "B"), ("B", "C")],
1574                // Add: D -> C
1575                vec![("D", "C")],
1576                // Fix: D
1577                vec!["D"],
1578            ),
1579            (
1580                "add_intermediate_edge",
1581                // Initial: A -> B -> C -> D
1582                vec![("A", "B"), ("B", "C"), ("C", "D")],
1583                // Add: A -> C (shortcut)
1584                vec![("A", "C")],
1585                // Fix: A
1586                vec!["A"],
1587            ),
1588            (
1589                "diamond_new_entry",
1590                // Initial: A -> B -> D, A -> C -> D
1591                vec![("A", "B"), ("A", "C"), ("B", "D"), ("C", "D")],
1592                // Add: E -> B
1593                vec![("E", "B")],
1594                // Fix: E
1595                vec!["E"],
1596            ),
1597            (
1598                "multiple_new_nodes",
1599                // Initial: A -> B -> C
1600                vec![("A", "B"), ("B", "C")],
1601                // Add: D -> A, E -> A
1602                vec![("D", "A"), ("E", "A")],
1603                // Fix: D, E
1604                vec!["D", "E"],
1605            ),
1606            (
1607                "deep_chain_extension",
1608                // Initial: A -> B -> C -> D -> E
1609                vec![("A", "B"), ("B", "C"), ("C", "D"), ("D", "E")],
1610                // Add: F -> A
1611                vec![("F", "A")],
1612                // Fix: F
1613                vec!["F"],
1614            ),
1615            (
1616                "wide_fan_in",
1617                // Initial: B -> D, C -> D
1618                vec![("B", "D"), ("C", "D")],
1619                // Add: A -> B, A -> C
1620                vec![("A", "B"), ("A", "C")],
1621                // Fix: A
1622                vec!["A"],
1623            ),
1624            (
1625                "bridge_disjoint_components",
1626                // Initial: A -> B, C -> D
1627                vec![("A", "B"), ("C", "D")],
1628                // Add: E -> B, E -> D
1629                vec![("E", "B"), ("E", "D")],
1630                // Fix: E
1631                vec!["E"],
1632            ),
1633            (
1634                "node_gains_second_parent",
1635                // Initial: A -> B -> C -> D
1636                vec![("A", "B"), ("B", "C"), ("C", "D")],
1637                // Add: B -> E, E -> D (B gets second path to D)
1638                vec![("B", "E"), ("E", "D")],
1639                // Fix: A, B, E
1640                // Must fix A too since it depends on B.
1641                // In entities.rs, this is correctly done: entities_touched is expanded
1642                // when any entitity has one of the touched nodes as an ancestor
1643                // See the corresponding test with same name in that file.
1644                vec!["A", "B", "E"],
1645            ),
1646        ]
1647    }
1648
1649    /// Regression test: `compute_tc` with `enforce_dag=false` on a cycle of
1650    /// length >= 3 must produce a complete transitive closure (every node in
1651    /// the cycle can reach every other node). Before the fix, the single-pass
1652    /// DFS algorithm produced incomplete, non-deterministic results for such
1653    /// cycles.
1654    #[test]
1655    fn cycle_length_3_enforce_dag_false() {
1656        // A -> B -> C -> A (cycle of length 3)
1657        let mut a = Entity::with_uid(EntityUID::with_eid("A"));
1658        a.add_parent(EntityUID::with_eid("B"));
1659        let mut b = Entity::with_uid(EntityUID::with_eid("B"));
1660        b.add_parent(EntityUID::with_eid("C"));
1661        let mut c = Entity::with_uid(EntityUID::with_eid("C"));
1662        c.add_parent(EntityUID::with_eid("A"));
1663        let mut entities = HashMap::from([
1664            (a.uid().clone(), a),
1665            (b.uid().clone(), b),
1666            (c.uid().clone(), c),
1667        ]);
1668        // With enforce_dag=false, compute_tc must not error (cycles allowed)
1669        compute_tc(&mut entities, false).expect("compute_tc should succeed with enforce_dag=false");
1670        // Every node in the cycle must be able to reach every other node
1671        let a = &entities[&EntityUID::with_eid("A")];
1672        let b = &entities[&EntityUID::with_eid("B")];
1673        let c = &entities[&EntityUID::with_eid("C")];
1674        assert!(a.is_descendant_of(&EntityUID::with_eid("B")));
1675        assert!(a.is_descendant_of(&EntityUID::with_eid("C")));
1676        assert!(b.is_descendant_of(&EntityUID::with_eid("A")));
1677        assert!(b.is_descendant_of(&EntityUID::with_eid("C")));
1678        assert!(c.is_descendant_of(&EntityUID::with_eid("A")));
1679        assert!(c.is_descendant_of(&EntityUID::with_eid("B")));
1680        // enforce_tc must pass (TC is complete)
1681        assert!(enforce_tc(&entities).is_ok());
1682    }
1683
1684    /// Runs all `repair_tc` test cases, comparing repair result against a
1685    /// full `compute_tc` on the same final graph.
1686    #[test]
1687    fn repair_tc_cases() {
1688        for (name, edges, new_edges, nodes_to_fix) in &repair_tc_test_cases() {
1689            // Collect all node ids
1690            let mut ids: HashSet<&str> = HashSet::new();
1691            for (s, d) in edges.iter().chain(new_edges.iter()) {
1692                ids.insert(s);
1693                ids.insert(d);
1694            }
1695
1696            // Build initial graph
1697            let mut entities: HashMap<EntityUID, Entity> = ids
1698                .iter()
1699                .map(|id| {
1700                    let e = Entity::with_uid(EntityUID::with_eid(id));
1701                    (e.uid().clone(), e)
1702                })
1703                .collect();
1704            for (src, dst) in edges {
1705                entities
1706                    .get_mut(&EntityUID::with_eid(src))
1707                    .unwrap()
1708                    .add_parent(EntityUID::with_eid(dst));
1709            }
1710
1711            // Compute initial TC
1712            compute_tc(&mut entities, false).expect("initial compute_tc failed");
1713
1714            // Reset nodes_to_fix and re-add their direct edges (initial + new)
1715            let fix_ids: HashSet<&str> = nodes_to_fix.iter().copied().collect();
1716            for id in &fix_ids {
1717                entities
1718                    .get_mut(&EntityUID::with_eid(id))
1719                    .unwrap()
1720                    .reset_edges();
1721            }
1722            for (src, dst) in edges.iter().chain(new_edges.iter()) {
1723                if fix_ids.contains(src) {
1724                    entities
1725                        .get_mut(&EntityUID::with_eid(src))
1726                        .unwrap()
1727                        .add_parent(EntityUID::with_eid(dst));
1728                }
1729            }
1730            // Insert any brand-new nodes that weren't in the initial graph
1731            for id in &fix_ids {
1732                entities
1733                    .entry(EntityUID::with_eid(id))
1734                    .or_insert_with(|| Entity::with_uid(EntityUID::with_eid(id)));
1735            }
1736
1737            // Repair TC
1738            let fix_set: HashSet<EntityUID> = nodes_to_fix
1739                .iter()
1740                .map(|id| EntityUID::with_eid(id))
1741                .collect();
1742            repair_tc(fix_set, &mut entities, false)
1743                .unwrap_or_else(|_| panic!("repair_tc failed on '{name}'"));
1744            let repaired = snapshot(&entities);
1745
1746            // Build expected: full TC from scratch on the combined edge set
1747            let mut expected_entities: HashMap<EntityUID, Entity> = ids
1748                .iter()
1749                .map(|id| {
1750                    let e = Entity::with_uid(EntityUID::with_eid(id));
1751                    (e.uid().clone(), e)
1752                })
1753                .collect();
1754            for (src, dst) in edges.iter().chain(new_edges.iter()) {
1755                expected_entities
1756                    .get_mut(&EntityUID::with_eid(src))
1757                    .unwrap()
1758                    .add_parent(EntityUID::with_eid(dst));
1759            }
1760            compute_tc(&mut expected_entities, false).expect("expected compute_tc failed");
1761            let expected = snapshot(&expected_entities);
1762
1763            assert_eq!(
1764                repaired, expected,
1765                "repair_tc result differs from full compute_tc on '{name}'"
1766            );
1767        }
1768    }
1769}