plotnik_compiler/analyze/
dependencies.rs

1//! Dependency analysis for definitions.
2//!
3//! Computes the dependency graph of definitions and identifies Strongly Connected
4//! Components (SCCs). The computed SCCs are exposed in reverse topological order
5//! (leaves first), which is useful for passes that need to process dependencies
6//! before dependents (like type inference).
7
8use std::collections::{HashMap, HashSet};
9
10use indexmap::{IndexMap, IndexSet};
11use plotnik_core::{Interner, Symbol};
12
13use super::symbol_table::SymbolTable;
14use super::type_check::DefId;
15use crate::parser::{Expr, Ref};
16
17/// Result of dependency analysis.
18#[derive(Clone, Debug, Default)]
19pub struct DependencyAnalysis {
20    /// Strongly connected components in reverse topological order.
21    ///
22    /// - `sccs[0]` has no dependencies (or depends only on things not in this list).
23    /// - `sccs.last()` depends on everything else.
24    /// - Definitions within an SCC are mutually recursive.
25    /// - Every definition in the symbol table appears exactly once.
26    pub sccs: Vec<Vec<String>>,
27
28    /// Maps definition name (Symbol) to its DefId.
29    name_to_def: HashMap<Symbol, DefId>,
30
31    /// Maps DefId to definition name Symbol (indexed by DefId).
32    def_names: Vec<Symbol>,
33
34    /// Set of recursive definition names.
35    ///
36    /// A definition is recursive if it's in an SCC with >1 member,
37    /// or it's a single-member SCC that references itself.
38    recursive_defs: HashSet<String>,
39}
40
41impl DependencyAnalysis {
42    /// Get the DefId for a definition by Symbol.
43    pub fn def_id_by_symbol(&self, sym: Symbol) -> Option<DefId> {
44        self.name_to_def.get(&sym).copied()
45    }
46
47    /// Get the DefId for a definition name (requires interner for lookup).
48    pub fn def_id(&self, interner: &Interner, name: &str) -> Option<DefId> {
49        // Linear scan - only used during analysis, not hot path
50        for (&sym, &def_id) in &self.name_to_def {
51            if interner.resolve(sym) == name {
52                return Some(def_id);
53            }
54        }
55        None
56    }
57
58    /// Get the name Symbol for a DefId.
59    pub fn def_name_sym(&self, id: DefId) -> Symbol {
60        self.def_names[id.index()]
61    }
62
63    /// Get the name string for a DefId.
64    pub fn def_name<'a>(&self, interner: &'a Interner, id: DefId) -> &'a str {
65        interner.resolve(self.def_names[id.index()])
66    }
67
68    /// Number of definitions.
69    pub fn def_count(&self) -> usize {
70        self.def_names.len()
71    }
72
73    /// Get the def_names slice (for seeding TypeContext).
74    pub fn def_names(&self) -> &[Symbol] {
75        &self.def_names
76    }
77
78    /// Get the name_to_def map (for seeding TypeContext).
79    pub fn name_to_def(&self) -> &HashMap<Symbol, DefId> {
80        &self.name_to_def
81    }
82
83    /// Returns true if this definition is recursive.
84    ///
85    /// A definition is recursive if it's part of a mutual recursion group (SCC > 1),
86    /// or it's a single definition that references itself.
87    pub fn is_recursive(&self, name: &str) -> bool {
88        self.recursive_defs.contains(name)
89    }
90}
91
92/// Analyze dependencies between definitions.
93///
94/// Returns the SCCs in reverse topological order, with DefId mappings.
95/// The interner is used to intern definition names as Symbols.
96pub fn analyze_dependencies(
97    symbol_table: &SymbolTable,
98    interner: &mut Interner,
99) -> DependencyAnalysis {
100    let sccs = SccFinder::find(symbol_table);
101
102    // Assign DefIds in SCC order (leaves first, so dependencies get lower IDs)
103    let mut name_to_def = HashMap::new();
104    let mut def_names = Vec::new();
105    let mut recursive_defs = HashSet::new();
106
107    for scc in &sccs {
108        // Mark recursive definitions
109        if scc.len() > 1 {
110            // Mutual recursion: all members are recursive
111            recursive_defs.extend(scc.iter().cloned());
112        } else if let Some(name) = scc.first()
113            && let Some(body) = symbol_table.get(name)
114            && super::refs::contains_ref(body, name)
115        {
116            recursive_defs.insert(name.clone());
117        }
118
119        for name in scc {
120            let sym = interner.intern(name);
121            let def_id = DefId::from_raw(def_names.len() as u32);
122            name_to_def.insert(sym, def_id);
123            def_names.push(sym);
124        }
125    }
126
127    DependencyAnalysis {
128        sccs,
129        name_to_def,
130        def_names,
131        recursive_defs,
132    }
133}
134
135struct SccFinder<'a> {
136    symbol_table: &'a SymbolTable,
137    index: usize,
138    stack: Vec<&'a str>,
139    on_stack: IndexSet<&'a str>,
140    indices: IndexMap<&'a str, usize>,
141    lowlinks: IndexMap<&'a str, usize>,
142    sccs: Vec<Vec<&'a str>>,
143}
144
145impl<'a> SccFinder<'a> {
146    fn find(symbol_table: &'a SymbolTable) -> Vec<Vec<String>> {
147        let mut finder = Self {
148            symbol_table,
149            index: 0,
150            stack: Vec::new(),
151            on_stack: IndexSet::new(),
152            indices: IndexMap::new(),
153            lowlinks: IndexMap::new(),
154            sccs: Vec::new(),
155        };
156
157        for name in symbol_table.keys() {
158            if !finder.indices.contains_key(name as &str) {
159                finder.strongconnect(name);
160            }
161        }
162
163        finder
164            .sccs
165            .into_iter()
166            .map(|scc| scc.into_iter().map(String::from).collect())
167            .collect()
168    }
169
170    fn strongconnect(&mut self, name: &'a str) {
171        self.indices.insert(name, self.index);
172        self.lowlinks.insert(name, self.index);
173        self.index += 1;
174        self.stack.push(name);
175        self.on_stack.insert(name);
176
177        if let Some(body) = self.symbol_table.get(name) {
178            let refs = collect_refs(body, self.symbol_table);
179            for ref_name in refs {
180                if !self.indices.contains_key(ref_name) {
181                    self.strongconnect(ref_name);
182                    let ref_lowlink = self.lowlinks[ref_name];
183                    let my_lowlink = self.lowlinks.get_mut(name).unwrap();
184                    *my_lowlink = (*my_lowlink).min(ref_lowlink);
185                } else if self.on_stack.contains(ref_name) {
186                    let ref_index = self.indices[ref_name];
187                    let my_lowlink = self.lowlinks.get_mut(name).unwrap();
188                    *my_lowlink = (*my_lowlink).min(ref_index);
189                }
190            }
191        }
192
193        if self.lowlinks[name] == self.indices[name] {
194            let mut scc = Vec::new();
195            loop {
196                let w = self.stack.pop().unwrap();
197                self.on_stack.swap_remove(&w);
198                let done = w == name;
199                scc.push(w);
200                if done {
201                    break;
202                }
203            }
204            self.sccs.push(scc);
205        }
206    }
207}
208
209/// Collect references to definitions within the symbol table.
210///
211/// Returns only refs that point to defined names (filters out node type references).
212pub(super) fn collect_refs<'a>(expr: &Expr, symbol_table: &'a SymbolTable) -> IndexSet<&'a str> {
213    let mut refs = IndexSet::new();
214    for descendant in expr.as_cst().descendants() {
215        let Some(r) = Ref::cast(descendant) else {
216            continue;
217        };
218        let Some(name_tok) = r.name() else { continue };
219        let Some(key) = symbol_table.keys().find(|&k| k == name_tok.text()) else {
220            continue;
221        };
222        refs.insert(key);
223    }
224    refs
225}