Skip to main content

lisette_semantics/checker/
scopes.rs

1use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet};
2use std::cell::Cell;
3use syntax::ast::BindingId;
4use syntax::ast::Span;
5use syntax::types::Type;
6
/// Nesting-depth counter that can be changed through a shared reference.
///
/// Backed by a [`Cell`], so every operation takes `&self`; this lets the
/// checker adjust depths stored on a scope it only borrows immutably.
#[derive(Clone, Debug, Default)]
pub struct DepthCounter(Cell<usize>);

impl DepthCounter {
    /// Returns a counter at depth zero.
    pub fn new() -> Self {
        Self::with_value(0)
    }

    /// Returns a counter that starts at depth `n`.
    pub fn with_value(n: usize) -> Self {
        DepthCounter(Cell::new(n))
    }

    /// Current depth.
    pub fn get(&self) -> usize {
        let DepthCounter(cell) = self;
        cell.get()
    }

    /// Adds one to the depth.
    pub fn increment(&self) {
        let deeper = self.get() + 1;
        self.restore(deeper);
    }

    /// Subtracts one from the depth, stopping at zero instead of underflowing.
    pub fn decrement(&self) {
        let shallower = self.get().saturating_sub(1);
        self.restore(shallower);
    }

    /// `true` while the depth is non-zero.
    pub fn is_active(&self) -> bool {
        self.get() != 0
    }

    /// Sets the depth back to zero and returns what it was, so the caller can
    /// later put it back with [`DepthCounter::restore`].
    pub fn reset(&self) -> usize {
        self.0.replace(0)
    }

    /// Overwrites the depth with `depth`.
    pub fn restore(&self, depth: usize) {
        self.0.set(depth);
    }
}
38
/// How the expression currently being checked is being used. Stored per scope
/// in a `Cell` and switched via the `Scopes::set_*_context` helpers, which
/// return the previous value so it can be restored.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum UseContext {
    /// Default context. NOTE(review): presumably "result is discarded"
    /// (statement position) — confirm against callers.
    #[default]
    Statement,
    /// NOTE(review): presumably "result is consumed as a value" — confirm.
    Value,
    /// NOTE(review): presumably "expression is in call (callee) position" —
    /// confirm.
    Callee,
}
46
/// Carrier type kind (`Result` or `Option`) recorded for a try block in
/// `TryBlockContext::carrier`.
/// NOTE(review): presumably determined by the first `?` inside the block —
/// confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CarrierKind {
    Result,
    Option,
}
52
/// State for a `try` block being checked. Stored in `Scope::try_block_context`
/// and found via `Scopes::lookup_try_block_context`, which does not look past
/// the nearest function boundary.
#[derive(Debug, Clone)]
pub struct TryBlockContext {
    /// NOTE(review): presumably the success type of the block's carrier.
    pub ok_ty: Type,
    /// NOTE(review): presumably the error type of the block's carrier.
    pub err_ty: Type,
    /// Carrier kind. Held in a `Cell` so it can be filled in through a shared
    /// reference; presumably `None` until it has been determined.
    pub carrier: Cell<Option<CarrierKind>>,
    /// Whether a `?` has been seen in the block. Held in a `Cell` so it can be
    /// set through a shared reference.
    pub has_question_mark: Cell<bool>,
    /// Source span of the `try` block (presumably used for diagnostics).
    pub try_span: Span,
    /// Loop depth tracked for this block.
    /// NOTE(review): presumably counts loops entered inside the block — confirm.
    pub loop_depth: DepthCounter,
}
62
/// State for a `recover` block being checked. Stored in
/// `Scope::recover_block_context` and found via
/// `Scopes::lookup_recover_block_context`, which does not look past the
/// nearest function boundary.
#[derive(Debug, Clone)]
pub struct RecoverBlockContext {
    /// NOTE(review): presumably the type produced by the block body — confirm.
    pub inner_ty: Type,
    /// Source span of the `recover` block (presumably used for diagnostics).
    pub recover_span: Span,
    /// Loop depth tracked for this block.
    /// NOTE(review): presumably counts loops entered inside the block — confirm.
    pub loop_depth: DepthCounter,
}
69
/// One lexical scope on the checker's scope stack (see `Scopes`).
///
/// Optional fields are `None` unless this scope introduces that kind of
/// information; lookups in `Scopes` walk outwards through enclosing scopes.
#[derive(Debug, Clone)]
pub struct Scope {
    /// variable name -> type
    pub values: HashMap<String, Type>,
    /// Names recorded as mutable in this scope; `None` means none recorded.
    pub mutables: Option<HashSet<String>>,
    /// Type parameter name -> index (returned by `Scopes::lookup_type_param`).
    /// NOTE(review): presumably the position in the generic parameter list.
    pub type_params: Option<HashMap<String, usize>>,
    /// Name -> trait bounds; merged across the stack by
    /// `Scopes::collect_all_trait_bounds`, with inner scopes winning.
    /// NOTE(review): keys are presumably type parameter names.
    pub trait_bounds: Option<HashMap<String, Vec<Type>>>,
    /// A `Some` here marks a function boundary: try/recover context lookups
    /// stop at this scope. Presumably set on the scope opening a function body.
    pub fn_return_type: Option<Type>,
    /// Context of a `try` block opened in this scope, if any.
    pub try_block_context: Option<TryBlockContext>,
    /// Context of a `recover` block opened in this scope, if any.
    pub recover_block_context: Option<RecoverBlockContext>,
    /// NOTE(review): presumably the type of values passed to `break` in the
    /// enclosing loop. Copied into child scopes by `Scopes::push`.
    pub loop_break_type: Option<Type>,
    /// Loop nesting depth; copied into child scopes by `Scopes::push`.
    pub loop_depth: DepthCounter,
    /// Defer-block nesting depth; copied into child scopes by `Scopes::push`.
    pub defer_block_depth: DepthCounter,
    /// Negation nesting depth; copied into child scopes by `Scopes::push`.
    pub negation_depth: DepthCounter,
    /// How the current expression is used; copied into child scopes by
    /// `Scopes::push`.
    pub use_context: Cell<UseContext>,
    /// variable name -> binding ID (for linting)
    pub name_to_binding: HashMap<String, BindingId>,
}
88
89impl Default for Scope {
90    fn default() -> Self {
91        Self::new()
92    }
93}
94
95impl Scope {
96    pub fn new() -> Self {
97        Scope {
98            values: HashMap::default(),
99            mutables: None,
100            type_params: None,
101            trait_bounds: None,
102            fn_return_type: None,
103            try_block_context: None,
104            recover_block_context: None,
105            loop_break_type: None,
106            loop_depth: DepthCounter::new(),
107            defer_block_depth: DepthCounter::new(),
108            negation_depth: DepthCounter::new(),
109            use_context: Cell::new(UseContext::Statement),
110            name_to_binding: HashMap::default(),
111        }
112    }
113}
114
115pub struct Scopes {
116    stack: Vec<Scope>,
117}
118
119impl Default for Scopes {
120    fn default() -> Self {
121        Self::new()
122    }
123}
124
125impl Scopes {
126    pub fn new() -> Self {
127        Scopes {
128            stack: vec![Scope::new()],
129        }
130    }
131
132    pub fn current(&self) -> &Scope {
133        self.stack.last().expect("scope stack must not be empty")
134    }
135
136    pub fn current_mut(&mut self) -> &mut Scope {
137        self.stack
138            .last_mut()
139            .expect("scope stack must not be empty")
140    }
141
142    pub fn push(&mut self) {
143        let current = self.current();
144        let loop_depth = current.loop_depth.get();
145        let defer_block_depth = current.defer_block_depth.get();
146        let negation_depth = current.negation_depth.get();
147        let use_context = current.use_context.get();
148        let loop_break_type = current.loop_break_type.clone();
149        self.stack.push(Scope {
150            values: HashMap::default(),
151            mutables: None,
152            type_params: None,
153            trait_bounds: None,
154            fn_return_type: None,
155            try_block_context: None,
156            recover_block_context: None,
157            loop_break_type,
158            loop_depth: DepthCounter::with_value(loop_depth),
159            defer_block_depth: DepthCounter::with_value(defer_block_depth),
160            negation_depth: DepthCounter::with_value(negation_depth),
161            use_context: Cell::new(use_context),
162            name_to_binding: HashMap::default(),
163        });
164    }
165
166    pub fn pop(&mut self) {
167        if self.stack.len() > 1 {
168            self.stack.pop();
169        }
170    }
171
172    pub fn reset(&mut self) {
173        self.stack.clear();
174        self.stack.push(Scope::new());
175    }
176
177    /// Look up a value by walking the scope stack from top to bottom.
178    pub fn lookup_value(&self, name: &str) -> Option<&Type> {
179        for scope in self.stack.iter().rev() {
180            if let Some(ty) = scope.values.get(name) {
181                return Some(ty);
182            }
183        }
184        None
185    }
186
187    /// Check if a variable is marked mutable in any enclosing scope.
188    pub fn lookup_mutable(&self, name: &str) -> bool {
189        self.stack
190            .iter()
191            .rev()
192            .any(|s| s.mutables.as_ref().is_some_and(|m| m.contains(name)))
193    }
194
195    /// Look up a binding ID by walking the scope stack from top to bottom.
196    pub fn lookup_binding_id(&self, name: &str) -> Option<BindingId> {
197        for scope in self.stack.iter().rev() {
198            if let Some(id) = scope.name_to_binding.get(name) {
199                return Some(*id);
200            }
201        }
202        None
203    }
204
205    /// Look up a type parameter by walking the scope stack from top to bottom.
206    pub fn lookup_type_param(&self, name: &str) -> Option<usize> {
207        for scope in self.stack.iter().rev() {
208            if let Some(idx) = scope.type_params.as_ref().and_then(|tp| tp.get(name)) {
209                return Some(*idx);
210            }
211        }
212        None
213    }
214
215    /// Look up the enclosing function's return type.
216    pub fn lookup_fn_return_type(&self) -> Option<&Type> {
217        for scope in self.stack.iter().rev() {
218            if let Some(ref ty) = scope.fn_return_type {
219                return Some(ty);
220            }
221        }
222        None
223    }
224
225    /// Look up the enclosing try block context, stopping at function boundaries.
226    pub fn lookup_try_block_context(&self) -> Option<&TryBlockContext> {
227        for scope in self.stack.iter().rev() {
228            if scope.try_block_context.is_some() {
229                return scope.try_block_context.as_ref();
230            }
231            if scope.fn_return_type.is_some() {
232                return None;
233            }
234        }
235        None
236    }
237
238    /// Look up the enclosing recover block context, stopping at function boundaries.
239    pub fn lookup_recover_block_context(&self) -> Option<&RecoverBlockContext> {
240        for scope in self.stack.iter().rev() {
241            if scope.recover_block_context.is_some() {
242                return scope.recover_block_context.as_ref();
243            }
244            if scope.fn_return_type.is_some() {
245                return None;
246            }
247        }
248        None
249    }
250
251    pub fn collect_all_value_names(&self) -> Vec<String> {
252        let mut names = Vec::new();
253        for scope in &self.stack {
254            names.extend(scope.values.keys().cloned());
255        }
256        names
257    }
258
259    pub fn collect_all_trait_bounds(&self) -> HashMap<String, Vec<Type>> {
260        let mut all_bounds = HashMap::default();
261        // Walk from bottom to top so inner scopes override outer
262        for scope in &self.stack {
263            if let Some(ref bounds) = scope.trait_bounds {
264                for (key, value) in bounds {
265                    all_bounds.insert(key.clone(), value.clone());
266                }
267            }
268        }
269        all_bounds
270    }
271
272    pub fn increment_loop_depth(&self) {
273        self.current().loop_depth.increment();
274    }
275
276    pub fn decrement_loop_depth(&self) {
277        self.current().loop_depth.decrement();
278    }
279
280    pub fn is_inside_loop(&self) -> bool {
281        self.current().loop_depth.is_active()
282    }
283
284    pub fn set_loop_break_type(&mut self, ty: Type) {
285        self.current_mut().loop_break_type = Some(ty);
286    }
287
288    pub fn clear_loop_break_type(&mut self) {
289        self.current_mut().loop_break_type = None;
290    }
291
292    pub fn loop_break_type(&self) -> Option<&Type> {
293        self.current().loop_break_type.as_ref()
294    }
295
296    pub fn increment_defer_block_depth(&self) {
297        self.current().defer_block_depth.increment();
298    }
299
300    pub fn decrement_defer_block_depth(&self) {
301        self.current().defer_block_depth.decrement();
302    }
303
304    pub fn is_inside_defer_block(&self) -> bool {
305        self.current().defer_block_depth.is_active()
306    }
307
308    pub fn defer_block_loop_depth(&self) -> usize {
309        self.current().loop_depth.get()
310    }
311
312    pub fn increment_negation_depth(&self) {
313        self.current().negation_depth.increment();
314    }
315
316    pub fn decrement_negation_depth(&self) {
317        self.current().negation_depth.decrement();
318    }
319
320    pub fn is_inside_negation(&self) -> bool {
321        self.current().negation_depth.is_active()
322    }
323
324    pub fn reset_loop_depth(&self) -> usize {
325        self.current().loop_depth.reset()
326    }
327
328    pub fn restore_loop_depth(&self, depth: usize) {
329        self.current().loop_depth.restore(depth);
330    }
331
332    pub fn set_value_context(&self) -> UseContext {
333        let prev = self.current().use_context.get();
334        self.current().use_context.set(UseContext::Value);
335        prev
336    }
337
338    pub fn set_statement_context(&self) -> UseContext {
339        let prev = self.current().use_context.get();
340        self.current().use_context.set(UseContext::Statement);
341        prev
342    }
343
344    pub fn restore_use_context(&self, ctx: UseContext) {
345        self.current().use_context.set(ctx);
346    }
347
348    pub fn is_value_context(&self) -> bool {
349        self.current().use_context.get() == UseContext::Value
350    }
351
352    pub fn set_callee_context(&self) -> UseContext {
353        let prev = self.current().use_context.get();
354        self.current().use_context.set(UseContext::Callee);
355        prev
356    }
357
358    pub fn is_callee_context(&self) -> bool {
359        self.current().use_context.get() == UseContext::Callee
360    }
361}