Skip to main content

lisette_semantics/checker/
scopes.rs

1use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet};
2use std::cell::Cell;
3use syntax::ast::BindingId;
4use syntax::ast::Span;
5use syntax::types::{Symbol, Type};
6
/// A nesting-depth counter that can be adjusted through a shared reference.
///
/// The count lives in a [`Cell`], so every operation takes `&self`; owners can
/// bump or lower the depth while only holding an immutable borrow.
#[derive(Debug, Clone, Default)]
pub struct DepthCounter(Cell<usize>);

impl DepthCounter {
    /// Creates a counter starting at zero.
    pub fn new() -> Self {
        Self::with_value(0)
    }

    /// Creates a counter starting at `n`.
    pub fn with_value(n: usize) -> Self {
        DepthCounter(Cell::new(n))
    }

    /// Returns the current depth.
    pub fn get(&self) -> usize {
        self.0.get()
    }

    /// Raises the depth by one.
    pub fn increment(&self) {
        let depth = self.get();
        self.0.set(depth + 1);
    }

    /// Lowers the depth by one, clamping at zero instead of underflowing.
    pub fn decrement(&self) {
        let depth = self.get();
        self.0.set(depth.saturating_sub(1));
    }

    /// Returns `true` while the depth is non-zero.
    pub fn is_active(&self) -> bool {
        self.get() != 0
    }

    /// Sets the depth to zero and returns the value it held before, so it can
    /// later be handed back to [`restore`](Self::restore).
    pub fn reset(&self) -> usize {
        self.0.replace(0)
    }

    /// Overwrites the depth with `depth` (e.g. a value previously returned by
    /// [`reset`](Self::reset)).
    pub fn restore(&self, depth: usize) {
        self.0.set(depth);
    }
}
38
/// Syntactic position of the expression currently being checked.
///
/// The default is [`UseContext::Statement`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum UseContext {
    /// Statement position (the default).
    #[default]
    Statement,
    /// Value position.
    Value,
    /// Callee position of a call expression.
    Callee,
    /// Target of an assignment.
    AssignmentTarget,
}
47
/// Which carrier type a try block propagates through: `Result` or `Option`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CarrierKind {
    /// A `Result` carrier.
    Result,
    /// An `Option` carrier.
    Option,
}
53
54#[derive(Debug, Clone)]
55pub struct TryBlockContext {
56    pub ok_ty: Type,
57    pub err_ty: Type,
58    pub carrier: Cell<Option<CarrierKind>>,
59    pub has_question_mark: Cell<bool>,
60    pub try_span: Span,
61    pub loop_depth: DepthCounter,
62}
63
64#[derive(Debug, Clone)]
65pub struct RecoverBlockContext {
66    pub inner_ty: Type,
67    pub recover_span: Span,
68    pub loop_depth: DepthCounter,
69}
70
71#[derive(Debug, Clone)]
72pub struct Scope {
73    /// variable name -> type
74    pub values: HashMap<String, Type>,
75    pub mutables: Option<HashSet<String>>,
76    pub consts: Option<HashSet<String>>,
77    pub type_params: Option<HashMap<String, usize>>,
78    pub trait_bounds: Option<HashMap<Symbol, Vec<Type>>>,
79    pub fn_return_type: Option<Type>,
80    pub try_block_context: Option<TryBlockContext>,
81    pub recover_block_context: Option<RecoverBlockContext>,
82    pub loop_break_type: Option<Type>,
83    pub loop_depth: DepthCounter,
84    pub defer_block_depth: DepthCounter,
85    pub negation_depth: DepthCounter,
86    pub type_param_depth: DepthCounter,
87    pub use_context: Cell<UseContext>,
88    /// variable name -> binding ID (for linting)
89    pub name_to_binding: HashMap<String, BindingId>,
90}
91
92impl Default for Scope {
93    fn default() -> Self {
94        Self::new()
95    }
96}
97
98impl Scope {
99    pub fn new() -> Self {
100        Scope {
101            values: HashMap::default(),
102            mutables: None,
103            consts: None,
104            type_params: None,
105            trait_bounds: None,
106            fn_return_type: None,
107            try_block_context: None,
108            recover_block_context: None,
109            loop_break_type: None,
110            loop_depth: DepthCounter::new(),
111            defer_block_depth: DepthCounter::new(),
112            negation_depth: DepthCounter::new(),
113            type_param_depth: DepthCounter::new(),
114            use_context: Cell::new(UseContext::Statement),
115            name_to_binding: HashMap::default(),
116        }
117    }
118}
119
120pub struct Scopes {
121    stack: Vec<Scope>,
122    /// True when inferring the body of a match/select arm. Consumed by
123    /// `infer_break`/`infer_continue` to decide whether the enclosing loop
124    /// needs a Go label (since Go switch cases do not fall through).
125    in_match_arm: Cell<bool>,
126    /// One entry per enclosing loop; set to `true` when a break/continue is
127    /// encountered inside a match arm. The top is popped by the loop's
128    /// inference function and recorded on the Loop/While/For/WhileLet AST node.
129    loop_needs_label_stack: std::cell::RefCell<Vec<bool>>,
130    /// True when inferring inside a compound expression (call arg, binary
131    /// operand, etc.). Used to reject `Err(x)?`/`None?` and similar control-flow
132    /// in positions where they can never produce a value.
133    in_subexpression: Cell<bool>,
134    /// True when inferring the base of a dot-access chain. Suppresses the
135    /// record-struct-as-value error when the struct name is a type qualifier
136    /// (e.g. `lib.Point` in `lib.Point.sum`).
137    dot_access_base: Cell<bool>,
138    /// The enclosing impl block's receiver type, used to resolve `self`
139    /// parameter annotations inside the impl's methods. `None` outside impls.
140    /// Singleton because Lisette does not allow nested impl blocks.
141    impl_receiver_type: Option<Type>,
142}
143
144impl Default for Scopes {
145    fn default() -> Self {
146        Self::new()
147    }
148}
149
150impl Scopes {
151    pub fn new() -> Self {
152        Scopes {
153            stack: vec![Scope::new()],
154            in_match_arm: Cell::new(false),
155            loop_needs_label_stack: std::cell::RefCell::new(Vec::new()),
156            in_subexpression: Cell::new(false),
157            dot_access_base: Cell::new(false),
158            impl_receiver_type: None,
159        }
160    }
161
162    pub fn current(&self) -> &Scope {
163        self.stack.last().expect("scope stack must not be empty")
164    }
165
166    pub fn current_mut(&mut self) -> &mut Scope {
167        self.stack
168            .last_mut()
169            .expect("scope stack must not be empty")
170    }
171
172    pub fn push(&mut self) {
173        let current = self.current();
174        let mut scope = Scope::new();
175        scope.loop_break_type = current.loop_break_type.clone();
176        scope.loop_depth = DepthCounter::with_value(current.loop_depth.get());
177        scope.defer_block_depth = DepthCounter::with_value(current.defer_block_depth.get());
178        scope.negation_depth = DepthCounter::with_value(current.negation_depth.get());
179        scope.type_param_depth = DepthCounter::with_value(current.type_param_depth.get());
180        scope.use_context = Cell::new(current.use_context.get());
181        self.stack.push(scope);
182    }
183
184    pub fn pop(&mut self) {
185        if self.stack.len() > 1 {
186            self.stack.pop();
187        }
188    }
189
190    pub fn reset(&mut self) {
191        self.stack.clear();
192        self.stack.push(Scope::new());
193        self.in_match_arm.set(false);
194        self.loop_needs_label_stack.borrow_mut().clear();
195        self.in_subexpression.set(false);
196        self.dot_access_base.set(false);
197        self.impl_receiver_type = None;
198    }
199
200    /// Look up a value by walking the scope stack from top to bottom.
201    pub fn lookup_value(&self, name: &str) -> Option<&Type> {
202        for scope in self.stack.iter().rev() {
203            if let Some(ty) = scope.values.get(name) {
204                return Some(ty);
205            }
206        }
207        None
208    }
209
210    /// Check if a variable is marked mutable in any enclosing scope.
211    pub fn lookup_mutable(&self, name: &str) -> bool {
212        self.stack
213            .iter()
214            .rev()
215            .any(|s| s.mutables.as_ref().is_some_and(|m| m.contains(name)))
216    }
217
218    /// Whether `name` is a block-local `const` in any enclosing scope.
219    pub fn lookup_const(&self, name: &str) -> bool {
220        self.stack
221            .iter()
222            .rev()
223            .any(|s| s.consts.as_ref().is_some_and(|c| c.contains(name)))
224    }
225
226    /// Look up a binding ID by walking the scope stack from top to bottom.
227    pub fn lookup_binding_id(&self, name: &str) -> Option<BindingId> {
228        for scope in self.stack.iter().rev() {
229            if let Some(id) = scope.name_to_binding.get(name) {
230                return Some(*id);
231            }
232        }
233        None
234    }
235
236    /// Look up a type parameter by walking the scope stack from top to bottom.
237    pub fn lookup_type_param(&self, name: &str) -> Option<usize> {
238        for scope in self.stack.iter().rev() {
239            if let Some(idx) = scope.type_params.as_ref().and_then(|tp| tp.get(name)) {
240                return Some(*idx);
241            }
242        }
243        None
244    }
245
246    /// Look up the enclosing function's return type.
247    pub fn lookup_fn_return_type(&self) -> Option<&Type> {
248        for scope in self.stack.iter().rev() {
249            if let Some(ref ty) = scope.fn_return_type {
250                return Some(ty);
251            }
252        }
253        None
254    }
255
256    /// Look up the enclosing try block context, stopping at function boundaries.
257    pub fn lookup_try_block_context(&self) -> Option<&TryBlockContext> {
258        for scope in self.stack.iter().rev() {
259            if scope.try_block_context.is_some() {
260                return scope.try_block_context.as_ref();
261            }
262            if scope.fn_return_type.is_some() {
263                return None;
264            }
265        }
266        None
267    }
268
269    /// Look up the enclosing recover block context, stopping at function boundaries.
270    pub fn lookup_recover_block_context(&self) -> Option<&RecoverBlockContext> {
271        for scope in self.stack.iter().rev() {
272            if scope.recover_block_context.is_some() {
273                return scope.recover_block_context.as_ref();
274            }
275            if scope.fn_return_type.is_some() {
276                return None;
277            }
278        }
279        None
280    }
281
282    pub fn collect_all_value_names(&self) -> Vec<String> {
283        let mut names = Vec::new();
284        for scope in &self.stack {
285            names.extend(scope.values.keys().cloned());
286        }
287        names
288    }
289
290    pub fn collect_all_trait_bounds(&self) -> HashMap<Symbol, Vec<Type>> {
291        let mut all_bounds = HashMap::default();
292        // Walk from bottom to top so inner scopes override outer
293        for scope in &self.stack {
294            if let Some(ref bounds) = scope.trait_bounds {
295                for (key, value) in bounds {
296                    all_bounds.insert(key.clone(), value.clone());
297                }
298            }
299        }
300        all_bounds
301    }
302
303    pub fn for_each_bound_on_param<F: FnMut(&Type)>(&self, param_name: &str, mut visit: F) {
304        for scope in self.stack.iter().rev() {
305            let introduces = scope
306                .type_params
307                .as_ref()
308                .is_some_and(|tp| tp.contains_key(param_name));
309            if !introduces {
310                continue;
311            }
312            if let Some(ref bounds) = scope.trait_bounds {
313                for (key, types) in bounds {
314                    if key.last_segment() == param_name {
315                        for ty in types {
316                            visit(ty);
317                        }
318                    }
319                }
320            }
321            return;
322        }
323    }
324
325    pub fn increment_loop_depth(&self) {
326        self.current().loop_depth.increment();
327    }
328
329    pub fn decrement_loop_depth(&self) {
330        self.current().loop_depth.decrement();
331    }
332
333    pub fn is_inside_loop(&self) -> bool {
334        self.current().loop_depth.is_active()
335    }
336
337    pub fn set_loop_break_type(&mut self, ty: Type) {
338        self.current_mut().loop_break_type = Some(ty);
339    }
340
341    pub fn clear_loop_break_type(&mut self) {
342        self.current_mut().loop_break_type = None;
343    }
344
345    pub fn loop_break_type(&self) -> Option<&Type> {
346        self.current().loop_break_type.as_ref()
347    }
348
349    pub fn increment_defer_block_depth(&self) {
350        self.current().defer_block_depth.increment();
351    }
352
353    pub fn decrement_defer_block_depth(&self) {
354        self.current().defer_block_depth.decrement();
355    }
356
357    pub fn is_inside_defer_block(&self) -> bool {
358        self.current().defer_block_depth.is_active()
359    }
360
361    pub fn defer_block_loop_depth(&self) -> usize {
362        self.current().loop_depth.get()
363    }
364
365    pub fn increment_negation_depth(&self) {
366        self.current().negation_depth.increment();
367    }
368
369    pub fn decrement_negation_depth(&self) {
370        self.current().negation_depth.decrement();
371    }
372
373    pub fn is_inside_negation(&self) -> bool {
374        self.current().negation_depth.is_active()
375    }
376
377    pub fn reset_loop_depth(&self) -> usize {
378        self.current().loop_depth.reset()
379    }
380
381    pub fn restore_loop_depth(&self, depth: usize) {
382        self.current().loop_depth.restore(depth);
383    }
384
385    pub fn set_value_context(&self) -> UseContext {
386        let prev = self.current().use_context.get();
387        self.current().use_context.set(UseContext::Value);
388        prev
389    }
390
391    pub fn set_statement_context(&self) -> UseContext {
392        let prev = self.current().use_context.get();
393        self.current().use_context.set(UseContext::Statement);
394        prev
395    }
396
397    pub fn restore_use_context(&self, ctx: UseContext) {
398        self.current().use_context.set(ctx);
399    }
400
401    pub fn is_value_context(&self) -> bool {
402        self.current().use_context.get() == UseContext::Value
403    }
404
405    pub fn set_callee_context(&self) -> UseContext {
406        let prev = self.current().use_context.get();
407        self.current().use_context.set(UseContext::Callee);
408        prev
409    }
410
411    pub fn is_callee_context(&self) -> bool {
412        self.current().use_context.get() == UseContext::Callee
413    }
414
415    pub fn set_assignment_target_context(&self) -> UseContext {
416        let prev = self.current().use_context.get();
417        self.current().use_context.set(UseContext::AssignmentTarget);
418        prev
419    }
420
421    pub fn is_assignment_target_context(&self) -> bool {
422        self.current().use_context.get() == UseContext::AssignmentTarget
423    }
424
425    pub fn is_in_match_arm(&self) -> bool {
426        self.in_match_arm.get()
427    }
428
429    pub fn set_in_match_arm(&self, value: bool) -> bool {
430        self.in_match_arm.replace(value)
431    }
432
433    pub fn push_loop_needs_label(&self) {
434        self.loop_needs_label_stack.borrow_mut().push(false);
435    }
436
437    pub fn pop_loop_needs_label(&self) -> bool {
438        self.loop_needs_label_stack
439            .borrow_mut()
440            .pop()
441            .expect("loop_needs_label_stack must not be empty when popping")
442    }
443
444    pub fn mark_current_loop_needs_label(&self) {
445        if let Some(flag) = self.loop_needs_label_stack.borrow_mut().last_mut() {
446            *flag = true;
447        }
448    }
449
450    pub fn is_in_subexpression(&self) -> bool {
451        self.in_subexpression.get()
452    }
453
454    pub fn set_in_subexpression(&self, value: bool) -> bool {
455        self.in_subexpression.replace(value)
456    }
457
458    pub fn is_dot_access_base(&self) -> bool {
459        self.dot_access_base.get()
460    }
461
462    pub fn set_dot_access_base(&self, value: bool) -> bool {
463        self.dot_access_base.replace(value)
464    }
465
466    pub fn increment_type_param_depth(&self) {
467        self.current().type_param_depth.increment();
468    }
469
470    pub fn decrement_type_param_depth(&self) {
471        self.current().type_param_depth.decrement();
472    }
473
474    pub fn is_inside_type_param(&self) -> bool {
475        self.current().type_param_depth.is_active()
476    }
477
478    pub fn set_impl_receiver_type(&mut self, ty: Option<Type>) {
479        self.impl_receiver_type = ty;
480    }
481
482    pub fn impl_receiver_type(&self) -> Option<&Type> {
483        self.impl_receiver_type.as_ref()
484    }
485}