1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
//! Inheritance Graph Solver
//!
//! Manages the nominal inheritance relationships between classes and interfaces.
//! Provides O(1) subtype checks via lazy transitive closure and handles
//! Method Resolution Order (MRO) for member lookup.
use fixedbitset::FixedBitSet;
use rustc_hash::{FxHashMap, FxHashSet};
use std::cell::RefCell;
use std::collections::VecDeque;
use tsz_binder::SymbolId;
/// Represents a node in the inheritance graph.
///
/// One `ClassNode` exists per `SymbolId` that has been mentioned in a call to
/// `InheritanceGraph::add_inheritance`, either as the child or as a parent.
/// The two `Option` fields are lazily filled caches; `None` means "not yet
/// computed (or invalidated)".
#[derive(Debug, Clone, Default)]
struct ClassNode {
/// Direct parents (extends and implements), in the order they were
/// supplied to `add_inheritance`.
parents: Vec<SymbolId>,
/// Reverse edges: symbols that list this node as a direct parent
/// (kept for invalidation/reverse lookup).
children: Vec<SymbolId>,
/// Cached transitive closure (all ancestors). Bit `i` is set when the
/// symbol whose raw id is `i` is an ancestor of this node.
/// If None, it needs to be computed.
ancestors_bitset: Option<FixedBitSet>,
/// Cached Method Resolution Order (linearized ancestors), starting with
/// the node's own symbol. If None, it needs to be computed.
mro: Option<Vec<SymbolId>>,
}
/// Nominal inheritance graph keyed by `SymbolId`.
///
/// Both fields sit behind `RefCell` so that query methods taking `&self`
/// can still fill in the per-node caches on demand.
#[derive(Debug)]
pub struct InheritanceGraph {
/// Map from `SymbolId` to graph node data
nodes: RefCell<FxHashMap<SymbolId, ClassNode>>,
/// Maximum raw `SymbolId` value seen so far (for `BitSet` sizing:
/// ancestor bitsets are allocated with `max_symbol_id + 1` bits)
max_symbol_id: RefCell<usize>,
}
impl Default for InheritanceGraph {
fn default() -> Self {
Self::new()
}
}
impl InheritanceGraph {
/// Creates an empty graph: no nodes, and a maximum seen symbol id of 0.
pub fn new() -> Self {
    let nodes = RefCell::new(FxHashMap::default());
    let max_symbol_id = RefCell::new(0);
    Self {
        nodes,
        max_symbol_id,
    }
}
/// Register a class or interface and its direct parents.
///
/// # Arguments
/// * `child` - The `SymbolId` of the class/interface being defined
/// * `parents` - List of `SymbolIds` this type extends or implements
pub fn add_inheritance(&self, child: SymbolId, parents: &[SymbolId]) {
let mut nodes = self.nodes.borrow_mut();
let mut max_id = self.max_symbol_id.borrow_mut();
// Update max ID for bitset sizing
*max_id = (*max_id).max(child.0 as usize);
for &p in parents {
*max_id = (*max_id).max(p.0 as usize);
}
// Register child
let child_node = nodes.entry(child).or_default();
// Check if edges actually changed to avoid invalidating cache unnecessarily
if child_node.parents == parents {
return;
}
child_node.parents = parents.to_vec();
// Invalidate caches
child_node.ancestors_bitset = None;
child_node.mro = None;
// Register reverse edges (for future invalidation logic)
for &parent in parents {
let parent_node = nodes.entry(parent).or_default();
if !parent_node.children.contains(&child) {
parent_node.children.push(child);
}
}
}
/// Checks if `child` is a subtype of `ancestor` nominally.
///
/// This is an O(1) operation after the first lazy computation.
/// Returns `true` if `child` extends or implements `ancestor` (transitively).
pub fn is_derived_from(&self, child: SymbolId, ancestor: SymbolId) -> bool {
if child == ancestor {
return true;
}
// Fast path: check if nodes exist
let nodes = self.nodes.borrow();
if !nodes.contains_key(&child) || !nodes.contains_key(&ancestor) {
return false;
}
drop(nodes); // Release borrow for compute
self.ensure_transitive_closure(child);
let nodes = self.nodes.borrow();
if let Some(node) = nodes.get(&child)
&& let Some(bits) = &node.ancestors_bitset
{
return bits.contains(ancestor.0 as usize);
}
false
}
/// Gets the Method Resolution Order (MRO) for a symbol.
///
/// Returns a list of `SymbolIds` in the order they should be searched for members.
/// Implements a depth-first, left-to-right traversal (standard for TS/JS).
pub fn get_resolution_order(&self, symbol_id: SymbolId) -> Vec<SymbolId> {
self.ensure_mro(symbol_id);
let nodes = self.nodes.borrow();
if let Some(node) = nodes.get(&symbol_id)
&& let Some(mro) = &node.mro
{
return mro.clone();
}
vec![symbol_id] // Fallback: just the symbol itself
}
/// Finds a common ancestor (approximate Least Upper Bound) of two symbols.
///
/// Resolution steps:
/// 1. If `a` is derived from `b` (including `a == b`), returns `Some(b)`.
/// 2. If `b` is derived from `a`, returns `Some(a)`.
/// 3. If either symbol is unknown to the graph, returns `None`.
/// 4. Otherwise walks `a`'s resolution order and returns the first ancestor
///    that `b` is also derived from, or `None` if they share no ancestor.
///
/// With multiple inheritance (interfaces) several candidates may be equally
/// valid; the one that appears earliest in `a`'s resolution order wins.
pub fn find_common_ancestor(&self, a: SymbolId, b: SymbolId) -> Option<SymbolId> {
    if self.is_derived_from(a, b) {
        return Some(b);
    }
    if self.is_derived_from(b, a) {
        return Some(a);
    }

    // `is_derived_from` reports `false` for unregistered symbols, so rule
    // them out explicitly before consulting the resolution order.
    {
        let nodes = self.nodes.borrow();
        if !nodes.contains_key(&a) || !nodes.contains_key(&b) {
            return None;
        }
    }

    // Simplified approach: iterate A's MRO and return the first entry that is
    // also an ancestor of B. The MRO always starts with `a` itself, which
    // step 2 has already rejected, so it is skipped.
    self.get_resolution_order(a)
        .into_iter()
        .skip(1)
        .find(|&candidate| self.is_derived_from(b, candidate))
}
/// Reports whether adding the edge `child -> parent` would close a cycle.
///
/// That is the case exactly when `parent` is already derived from `child`
/// (which, per `is_derived_from`, includes `parent == child`).
pub fn detects_cycle(&self, child: SymbolId, parent: SymbolId) -> bool {
    let parent_already_descends_from_child = self.is_derived_from(parent, child);
    parent_already_descends_from_child
}
/// Returns a copy of the direct parents registered for `symbol_id`
/// (for cycle detection), or an empty list if the symbol is unknown.
pub fn get_parents(&self, symbol_id: SymbolId) -> Vec<SymbolId> {
    self.nodes
        .borrow()
        .get(&symbol_id)
        .map(|node| node.parents.clone())
        .unwrap_or_default()
}
// =========================================================================
// Internal Lazy Computation Methods
// =========================================================================
/// Makes sure the ancestor bitset of `symbol_id` is cached.
///
/// Does nothing when the symbol is unknown or its bitset already exists.
/// Otherwise the closure is computed recursively over the parent edges, with
/// bitsets sized to `max_symbol_id + 1` bits.
fn ensure_transitive_closure(&self, symbol_id: SymbolId) {
    let mut nodes = self.nodes.borrow_mut();

    // Only a registered node without a cached bitset needs any work.
    let needs_compute = nodes
        .get(&symbol_id)
        .is_some_and(|node| node.ancestors_bitset.is_none());
    if !needs_compute {
        return;
    }

    let bitset_len = *self.max_symbol_id.borrow() + 1;
    // Symbols on the current recursion path, used to break cycles.
    let mut in_progress = FxHashSet::default();
    self.compute_closure_recursive(symbol_id, &mut nodes, &mut in_progress, bitset_len);
}
/// Recursively computes and caches the ancestor bitset of `current`.
///
/// * `nodes` - the graph's node map, already mutably borrowed by the caller
/// * `path` - symbols on the current recursion stack; used to break cycles
/// * `bitset_len` - capacity for newly created bitsets; must exceed the raw
///   id of every parent reachable from `current`, because inserting a bit
///   beyond the capacity panics
///
/// Parents are computed first, then `current`'s set is the union of each
/// parent's own bit and that parent's cached ancestor set.
///
/// NOTE(review): when the graph contains a cycle, the back edge is skipped,
/// so nodes inside the cycle may end up with a cached set that is missing
/// some members of the cycle. The graph is expected to be acyclic.
// `self` is only threaded through the recursion; kept as a method for
// call-site symmetry with the other internal helpers.
#[allow(clippy::only_used_in_recursion)]
fn compute_closure_recursive(
    &self,
    current: SymbolId,
    nodes: &mut FxHashMap<SymbolId, ClassNode>,
    path: &mut FxHashSet<SymbolId>,
    bitset_len: usize,
) {
    if path.contains(&current) {
        // Cycle detected, stop recursion here.
        // In a real compiler, we might emit a diagnostic here,
        // but the solver just wants to avoid infinite loops.
        return;
    }

    // If already computed, we are good
    if let Some(node) = nodes.get(&current)
        && node.ancestors_bitset.is_some()
    {
        return;
    }

    path.insert(current);

    // Clone parents to avoid borrowing issues during recursion
    let parents = nodes
        .get(&current)
        .map(|node| node.parents.clone())
        .unwrap_or_default();

    let mut my_bits = FixedBitSet::with_capacity(bitset_len);
    for parent in parents {
        // Ensure parent is computed
        self.compute_closure_recursive(parent, nodes, path, bitset_len);

        // Add parent itself
        my_bits.insert(parent.0 as usize);

        // Add parent's ancestors (absent only if the parent is on the
        // current path, i.e. we are inside a cycle)
        if let Some(parent_bits) = nodes
            .get(&parent)
            .and_then(|parent_node| parent_node.ancestors_bitset.as_ref())
        {
            my_bits.union_with(parent_bits);
        }
    }

    // Save result
    if let Some(node) = nodes.get_mut(&current) {
        node.ancestors_bitset = Some(my_bits);
    }
    path.remove(&current);
}
/// Lazily computes and caches the MRO for a node.
///
/// Does nothing when the symbol is unknown or its MRO is already cached.
///
/// The order is produced by a breadth-first walk over the parent edges:
/// the symbol itself comes first, then its direct parents left to right,
/// then their parents, and so on. Each symbol is listed once, at its first
/// occurrence, which also makes the walk terminate on cyclic graphs.
fn ensure_mro(&self, symbol_id: SymbolId) {
    let mut nodes = self.nodes.borrow_mut();

    let needs_compute = nodes
        .get(&symbol_id)
        .is_some_and(|node| node.mro.is_none());
    if !needs_compute {
        return;
    }

    // Breadth-first, left-to-right traversal driven by a FIFO queue.
    // (Note: Python uses C3 linearization; this is deliberately simpler.)
    let mut mro = Vec::new();
    let mut visited = FxHashSet::default();
    let mut queue = VecDeque::new();
    queue.push_back(symbol_id);

    while let Some(current) = queue.pop_front() {
        if !visited.insert(current) {
            continue;
        }
        mro.push(current);

        if let Some(node) = nodes.get(&current) {
            // Add parents to queue
            // For class extends A implements B, C -> A, B, C
            queue.extend(node.parents.iter().copied());
        }
    }

    if let Some(node) = nodes.get_mut(&symbol_id) {
        node.mro = Some(mro);
    }
}
/// Resets the graph to its freshly constructed state (useful for testing or
/// rebuilding): every node, with its edges and caches, is removed and the
/// tracked maximum symbol id goes back to 0.
pub fn clear(&self) {
    let mut nodes = self.nodes.borrow_mut();
    nodes.clear();

    let mut max_id = self.max_symbol_id.borrow_mut();
    *max_id = 0;
}
/// Number of nodes currently in the graph. Symbols that were only ever
/// mentioned as a parent count too, since they get a node of their own.
pub fn len(&self) -> usize {
    let nodes = self.nodes.borrow();
    nodes.len()
}
/// Returns `true` when the graph holds no nodes at all.
pub fn is_empty(&self) -> bool {
    let nodes = self.nodes.borrow();
    nodes.is_empty()
}
}
#[cfg(test)]
#[path = "../tests/inheritance_tests.rs"]
mod tests;