Skip to main content

luaur_analysis/functions/
drain.rs

//! Faithful port of `Luau::detail::drain`
//! (`Analysis/src/TopoSortStatements.cpp:404-497`).
//!
//! Drain `Q` until the target's `depends` arcs are satisfied. A non-null
//! `target` is always appended to the result.
6use crate::functions::is_block_terminator::is_block_terminator;
7use crate::functions::prune::prune;
8use crate::records::arcs::Arcs;
9use crate::records::node::Node;
10use crate::type_aliases::node_list::NodeList;
11use alloc::collections::{BTreeMap, BTreeSet};
12use luaur_ast::records::ast_stat::AstStat;
13use luaur_common::macros::luau_assert::LUAU_ASSERT;
14
15pub fn drain(q: &mut NodeList, result: &mut Vec<*mut AstStat>, target: *mut Node) {
16    // Trying to toposort a subgraph is a pretty big hassle. :(
17    // Some of the nodes in .depends and .provides aren't present in our subgraph
18
19    // std::map<Node*, Arcs> allArcs;
20    let mut all_arcs: BTreeMap<*mut Node, Arcs> = BTreeMap::new();
21
22    // `DenseHashSet<Node*> elements` — the set of nodes present in Q. In C++ it
23    // is (redundantly) rebuilt once per outer iteration; it is constant, so we
24    // build it a single time.
25    let elements: BTreeSet<*mut Node> = q.iter().copied().collect();
26
27    // for (auto& node : Q) { ... copy connectivity filtered to Q ... }
28    for &node_ptr in q.iter() {
29        let mut arcs = Arcs::new();
30        let node = unsafe { &*node_ptr };
31
32        for &dep in node.depends.iter() {
33            if elements.contains(&dep) {
34                arcs.depends.insert(dep);
35            }
36        }
37        for &prov in node.provides.iter() {
38            if elements.contains(&prov) {
39                arcs.provides.insert(prov);
40            }
41        }
42
43        all_arcs.insert(node_ptr, arcs);
44    }
45
46    // while (!Q.empty())
47    while !q.is_empty() {
48        // if (target && target->depends.empty()) { prune(target); push; return; }
49        if !target.is_null() && unsafe { (*target).depends.is_empty() } {
50            prune(target);
51            result.push(unsafe { (*target).element });
52            return;
53        }
54
55        let mut next_node: *mut Node = core::ptr::null_mut();
56
57        // Find the first non-terminator node whose (filtered) depends are empty.
58        for i in 0..q.len() {
59            let candidate = q[i];
60            if is_block_terminator(unsafe { &*(*candidate).element }) {
61                continue;
62            }
63
64            LUAU_ASSERT!(all_arcs.contains_key(&candidate));
65            let arcs = &all_arcs[&candidate];
66
67            if arcs.depends.is_empty() {
68                next_node = candidate;
69                q.remove(i);
70                break;
71            }
72        }
73
74        // if (!nextNode) { nextNode = std::move(Q.front()); Q.pop_front(); }
75        if next_node.is_null() {
76            // We've hit a cycle or a terminator. Pick an arbitrary node.
77            next_node = *q.front().unwrap();
78            q.pop_front();
79        }
80
81        // Remove nextNode from the filtered arcs of its neighbours.
82        let provides: alloc::vec::Vec<*mut Node> =
83            unsafe { (*next_node).provides.iter().copied().collect() };
84        let depends: alloc::vec::Vec<*mut Node> =
85            unsafe { (*next_node).depends.iter().copied().collect() };
86
87        for node in provides {
88            if let Some(arcs) = all_arcs.get_mut(&node) {
89                let removed = arcs.depends.remove(&next_node);
90                LUAU_ASSERT!(removed);
91            }
92        }
93
94        for node in depends {
95            if let Some(arcs) = all_arcs.get_mut(&node) {
96                let removed = arcs.provides.remove(&next_node);
97                LUAU_ASSERT!(removed);
98            }
99        }
100
101        prune(next_node);
102        result.push(unsafe { (*next_node).element });
103    }
104
105    // if (target) { prune(target); result.push_back(target->element); }
106    if !target.is_null() {
107        prune(target);
108        result.push(unsafe { (*target).element });
109    }
110}