plotnik_compiler/analyze/
dependencies.rs1use std::collections::{HashMap, HashSet};
9
10use indexmap::{IndexMap, IndexSet};
11use plotnik_core::{Interner, Symbol};
12
13use super::symbol_table::SymbolTable;
14use super::type_check::DefId;
15use crate::parser::{Expr, Ref};
16
/// Result of ordering the definitions of a [`SymbolTable`] by the references
/// their bodies make to one another.
///
/// Built by `analyze_dependencies`.
#[derive(Clone, Debug, Default)]
pub struct DependencyAnalysis {
    /// Strongly connected components of the definition reference graph, in the
    /// order Tarjan's algorithm completes them (a component is listed after the
    /// components it references). Each inner `Vec` holds the names of one
    /// component.
    pub sccs: Vec<Vec<String>>,

    /// Interned definition name -> its [`DefId`].
    name_to_def: HashMap<Symbol, DefId>,

    /// Interned definition names; entry `i` is the name given
    /// `DefId::from_raw(i)`. Ids are handed out consecutively while walking
    /// `sccs` in order, and accessors index this with `DefId::index()`.
    def_names: Vec<Symbol>,

    /// Names of recursive definitions: every member of a component with more
    /// than one definition, plus single-definition components whose body
    /// references the definition itself.
    recursive_defs: HashSet<String>,
}
40
41impl DependencyAnalysis {
42 pub fn def_id_by_symbol(&self, sym: Symbol) -> Option<DefId> {
44 self.name_to_def.get(&sym).copied()
45 }
46
47 pub fn def_id(&self, interner: &Interner, name: &str) -> Option<DefId> {
49 for (&sym, &def_id) in &self.name_to_def {
51 if interner.resolve(sym) == name {
52 return Some(def_id);
53 }
54 }
55 None
56 }
57
58 pub fn def_name_sym(&self, id: DefId) -> Symbol {
60 self.def_names[id.index()]
61 }
62
63 pub fn def_name<'a>(&self, interner: &'a Interner, id: DefId) -> &'a str {
65 interner.resolve(self.def_names[id.index()])
66 }
67
68 pub fn def_count(&self) -> usize {
70 self.def_names.len()
71 }
72
73 pub fn def_names(&self) -> &[Symbol] {
75 &self.def_names
76 }
77
78 pub fn name_to_def(&self) -> &HashMap<Symbol, DefId> {
80 &self.name_to_def
81 }
82
83 pub fn is_recursive(&self, name: &str) -> bool {
88 self.recursive_defs.contains(name)
89 }
90}
91
92pub fn analyze_dependencies(
97 symbol_table: &SymbolTable,
98 interner: &mut Interner,
99) -> DependencyAnalysis {
100 let sccs = SccFinder::find(symbol_table);
101
102 let mut name_to_def = HashMap::new();
104 let mut def_names = Vec::new();
105 let mut recursive_defs = HashSet::new();
106
107 for scc in &sccs {
108 if scc.len() > 1 {
110 recursive_defs.extend(scc.iter().cloned());
112 } else if let Some(name) = scc.first()
113 && let Some(body) = symbol_table.get(name)
114 && super::refs::contains_ref(body, name)
115 {
116 recursive_defs.insert(name.clone());
117 }
118
119 for name in scc {
120 let sym = interner.intern(name);
121 let def_id = DefId::from_raw(def_names.len() as u32);
122 name_to_def.insert(sym, def_id);
123 def_names.push(sym);
124 }
125 }
126
127 DependencyAnalysis {
128 sccs,
129 name_to_def,
130 def_names,
131 recursive_defs,
132 }
133}
134
135struct SccFinder<'a> {
136 symbol_table: &'a SymbolTable,
137 index: usize,
138 stack: Vec<&'a str>,
139 on_stack: IndexSet<&'a str>,
140 indices: IndexMap<&'a str, usize>,
141 lowlinks: IndexMap<&'a str, usize>,
142 sccs: Vec<Vec<&'a str>>,
143}
144
145impl<'a> SccFinder<'a> {
146 fn find(symbol_table: &'a SymbolTable) -> Vec<Vec<String>> {
147 let mut finder = Self {
148 symbol_table,
149 index: 0,
150 stack: Vec::new(),
151 on_stack: IndexSet::new(),
152 indices: IndexMap::new(),
153 lowlinks: IndexMap::new(),
154 sccs: Vec::new(),
155 };
156
157 for name in symbol_table.keys() {
158 if !finder.indices.contains_key(name as &str) {
159 finder.strongconnect(name);
160 }
161 }
162
163 finder
164 .sccs
165 .into_iter()
166 .map(|scc| scc.into_iter().map(String::from).collect())
167 .collect()
168 }
169
170 fn strongconnect(&mut self, name: &'a str) {
171 self.indices.insert(name, self.index);
172 self.lowlinks.insert(name, self.index);
173 self.index += 1;
174 self.stack.push(name);
175 self.on_stack.insert(name);
176
177 if let Some(body) = self.symbol_table.get(name) {
178 let refs = collect_refs(body, self.symbol_table);
179 for ref_name in refs {
180 if !self.indices.contains_key(ref_name) {
181 self.strongconnect(ref_name);
182 let ref_lowlink = self.lowlinks[ref_name];
183 let my_lowlink = self.lowlinks.get_mut(name).unwrap();
184 *my_lowlink = (*my_lowlink).min(ref_lowlink);
185 } else if self.on_stack.contains(ref_name) {
186 let ref_index = self.indices[ref_name];
187 let my_lowlink = self.lowlinks.get_mut(name).unwrap();
188 *my_lowlink = (*my_lowlink).min(ref_index);
189 }
190 }
191 }
192
193 if self.lowlinks[name] == self.indices[name] {
194 let mut scc = Vec::new();
195 loop {
196 let w = self.stack.pop().unwrap();
197 self.on_stack.swap_remove(&w);
198 let done = w == name;
199 scc.push(w);
200 if done {
201 break;
202 }
203 }
204 self.sccs.push(scc);
205 }
206 }
207}
208
209pub(super) fn collect_refs<'a>(expr: &Expr, symbol_table: &'a SymbolTable) -> IndexSet<&'a str> {
213 let mut refs = IndexSet::new();
214 for descendant in expr.as_cst().descendants() {
215 let Some(r) = Ref::cast(descendant) else {
216 continue;
217 };
218 let Some(name_tok) = r.name() else { continue };
219 let Some(key) = symbol_table.keys().find(|&k| k == name_tok.text()) else {
220 continue;
221 };
222 refs.insert(key);
223 }
224 refs
225}