Skip to main content

ling/borrowck/
checker.rs

1use ling_ast::Span;
2use ling_mir::ir::*;
3use ling_mir::liveness::Liveness;
4use std::collections::{HashMap, HashSet, VecDeque};
5
/// Initialization state of one local at a program point.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum LocalState {
    /// The local currently holds a usable value.
    Initialized,
    /// The local's value has been moved out or dropped.
    Moved,
    /// The local holds no value: never assigned, or its storage has ended.
    Dead,
}
12
13#[derive(Debug, Clone)]
14struct FlowState {
15    locals: Vec<LocalState>,
16    borrows: Vec<(usize, bool)>,
17}
18
19impl FlowState {
20    fn new(num_locals: usize) -> Self {
21        Self {
22            locals: vec![LocalState::Dead; num_locals],
23            borrows: vec![(0, false); num_locals],
24        }
25    }
26
27    fn get(&self, local: Local) -> LocalState {
28        self.locals
29            .get(local.0)
30            .copied()
31            .unwrap_or(LocalState::Dead)
32    }
33
34    fn set(&mut self, local: Local, state: LocalState) -> Result<(), String> {
35        if local.0 < self.locals.len() {
36            let current_borrows = self.borrows[local.0];
37            if current_borrows.0 > 0 || current_borrows.1 {
38                if state != LocalState::Initialized {
39                    return Err("cannot move or drop variable while it is borrowed".to_string());
40                } else {
41                    return Err("cannot reassign variable while it is borrowed".to_string());
42                }
43            }
44            self.locals[local.0] = state;
45            if state != LocalState::Initialized {
46                self.borrows[local.0] = (0, false);
47            }
48        }
49        Ok(())
50    }
51
52    fn join(&mut self, other: &FlowState) -> bool {
53        let mut changed = false;
54        let len = self.locals.len().max(other.locals.len());
55        if self.locals.len() < len {
56            self.locals.resize(len, LocalState::Dead);
57            self.borrows.resize(len, (0, false));
58        }
59        for (i, other_state) in other.locals.iter().enumerate() {
60            let merged = merge_states(self.locals[i], *other_state);
61            if merged != self.locals[i] {
62                self.locals[i] = merged;
63                changed = true;
64            }
65        }
66        for i in 0..self.borrows.len().min(other.borrows.len()) {
67            let new_count = self.borrows[i].0.max(other.borrows[i].0);
68            let new_mut = self.borrows[i].1 || other.borrows[i].1;
69            if self.borrows[i].0 != new_count || self.borrows[i].1 != new_mut {
70                self.borrows[i] = (new_count, new_mut);
71                changed = true;
72            }
73        }
74        changed
75    }
76}
77
78fn merge_states(a: LocalState, b: LocalState) -> LocalState {
79    match (a, b) {
80        (LocalState::Moved, _) | (_, LocalState::Moved) => LocalState::Moved,
81        (LocalState::Initialized, _) | (_, LocalState::Initialized) => LocalState::Initialized,
82        _ => LocalState::Dead,
83    }
84}
85
86pub struct BorrowChecker<'a> {
87    pub func: &'a MirFunction,
88    pub errors: Vec<String>,
89    pub liveness: Liveness,
90    pub provenance: HashMap<Local, Local>,
91}
92
93impl<'a> BorrowChecker<'a> {
94    pub fn new(func: &'a MirFunction) -> Self {
95        let liveness = Liveness::compute(func);
96        Self {
97            func,
98            liveness,
99            errors: Vec::new(),
100            provenance: HashMap::default(),
101        }
102    }
103
104    pub fn check(&mut self) {
105        if self.func.basic_blocks.is_empty() {
106            return;
107        }
108
109        let num_blocks = self.func.basic_blocks.len();
110        // Flat local space: return slot, parameters, then temporaries. The
111        // return slot and parameters have no entry in `func.locals`.
112        let num_slots = self.func.next_local();
113
114        let mut entry_states: Vec<Option<FlowState>> = vec![None; num_blocks];
115
116        let mut init_state = FlowState::new(num_slots);
117        let _ = init_state.set(Local(0), LocalState::Initialized);
118        for i in 1..=self.func.arg_count {
119            let _ = init_state.set(Local(i), LocalState::Initialized);
120        }
121        entry_states[0] = Some(init_state);
122
123        let mut worklist: VecDeque<usize> = VecDeque::new();
124        worklist.push_back(0);
125        let mut in_worklist: Vec<bool> = vec![false; num_blocks];
126        in_worklist[0] = true;
127
128        while let Some(bb_idx) = worklist.pop_front() {
129            in_worklist[bb_idx] = false;
130
131            let state = match &entry_states[bb_idx] {
132                Some(s) => s.clone(),
133                None => continue,
134            };
135
136            let mut state = state;
137            let bb = &self.func.basic_blocks[bb_idx];
138
139            for (stmt_idx, stmt) in bb.statements.iter().enumerate() {
140                self.transfer_stmt(stmt, &mut state);
141                let live_after = &self.liveness.live_after[bb_idx][stmt_idx + 1];
142                self.release_dead_borrows(&mut state, live_after);
143            }
144
145            let term_idx = bb.statements.len();
146            if let Some(term) = &bb.terminator {
147                self.check_terminator(term, &mut state);
148            }
149
150            let live_after_term = &self.liveness.live_after[bb_idx][term_idx];
151            self.release_dead_borrows(&mut state, live_after_term);
152
153            let successors = self.successors(bb);
154            for succ in successors {
155                let changed = match &mut entry_states[succ] {
156                    None => {
157                        entry_states[succ] = Some(state.clone());
158                        true
159                    },
160                    Some(existing) => existing.join(&state),
161                };
162                if changed && !in_worklist[succ] {
163                    worklist.push_back(succ);
164                    in_worklist[succ] = true;
165                }
166            }
167        }
168    }
169
170    fn transfer_stmt(&mut self, stmt: &Statement, state: &mut FlowState) {
171        match &stmt.kind {
172            StatementKind::Assign(lhs, rvalue) => {
173                self.check_rvalue(rvalue, state, stmt.span);
174
175                match rvalue {
176                    Rvalue::Ref(rhs) | Rvalue::MutRef(rhs) => {
177                        self.provenance.insert(*lhs, *rhs);
178                    },
179                    Rvalue::Use(Operand::Copy(rhs)) | Rvalue::Use(Operand::Move(rhs)) => {
180                        if let Some(prov) = self.provenance.get(rhs).cloned() {
181                            self.provenance.insert(*lhs, prov);
182                        }
183                    },
184                    _ => {
185                        self.provenance.remove(lhs);
186                    },
187                }
188
189                if let Err(msg) = state.set(*lhs, LocalState::Initialized) {
190                    let name = self.local_name(*lhs);
191                    self.errors.push(format!("{} `{}`", msg, name));
192                }
193            },
194            StatementKind::SetAttr(obj, _field, val) => {
195                self.check_mutation(obj, state, stmt.span);
196                self.check_operand(obj, state, stmt.span);
197                self.check_operand(val, state, stmt.span);
198            },
199            StatementKind::SetIndex(obj, idx, val) => {
200                self.check_mutation(obj, state, stmt.span);
201                self.check_operand(obj, state, stmt.span);
202                self.check_operand(idx, state, stmt.span);
203                self.check_operand(val, state, stmt.span);
204            },
205            StatementKind::StorageLive(local) => {
206                let _ = state.set(*local, LocalState::Initialized);
207            },
208            StatementKind::StorageDead(local) => {
209                let _ = state.set(*local, LocalState::Dead);
210            },
211            StatementKind::Drop(local) => {
212                if state.get(*local) == LocalState::Initialized {
213                    if let Err(msg) = state.set(*local, LocalState::Moved) {
214                        let name = self.local_name(*local);
215                        self.errors
216                            .push(format!("{} `{}` (lifetime error)", msg, name));
217                    }
218                }
219            },
220            StatementKind::VectorStore(obj, idx, val) => {
221                self.check_mutation(obj, state, stmt.span);
222                self.check_operand(obj, state, stmt.span);
223                self.check_operand(idx, state, stmt.span);
224                self.check_operand(val, state, stmt.span);
225            },
226        }
227    }
228
229    fn check_terminator(&mut self, term: &Terminator, state: &mut FlowState) {
230        match &term.kind {
231            TerminatorKind::SwitchInt { discr, .. } => {
232                self.check_operand(discr, state, term.span);
233            },
234            TerminatorKind::Return | TerminatorKind::Goto { .. } | TerminatorKind::Unreachable => {
235            },
236        }
237    }
238
239    fn successors(&self, bb: &BasicBlock) -> Vec<usize> {
240        match &bb.terminator {
241            Some(t) => match &t.kind {
242                TerminatorKind::Goto { target } => vec![target.0],
243                TerminatorKind::SwitchInt { targets, otherwise, .. } => {
244                    let mut succs: Vec<usize> = targets.iter().map(|(_, bb)| bb.0).collect();
245                    succs.push(otherwise.0);
246                    succs
247                },
248                TerminatorKind::Return | TerminatorKind::Unreachable => vec![],
249            },
250            None => vec![],
251        }
252    }
253
254    fn check_rvalue(&mut self, rvalue: &Rvalue, state: &mut FlowState, span: Span) {
255        match rvalue {
256            Rvalue::Use(op) => self.check_operand(op, state, span),
257            Rvalue::BinaryOp(_, lhs, rhs) => {
258                self.check_operand(lhs, state, span);
259                self.check_operand(rhs, state, span);
260            },
261            Rvalue::UnaryOp(_, op) => self.check_operand(op, state, span),
262            Rvalue::Call { func, args } => {
263                self.check_operand(func, state, span);
264                for arg in args {
265                    self.check_operand(arg, state, span);
266                }
267            },
268            Rvalue::Aggregate(_, ops) => {
269                for op in ops {
270                    self.check_operand(op, state, span);
271                }
272            },
273            Rvalue::GetAttr(op, _) => self.check_operand(op, state, span),
274            Rvalue::GetIndex(obj, idx) => {
275                self.check_operand(obj, state, span);
276                self.check_operand(idx, state, span);
277            },
278            Rvalue::Ref(local) => {
279                let s = state.get(*local);
280                if s != LocalState::Initialized {
281                    let name = self.local_name(*local);
282                    self.errors.push(format!(
283                        "cannot borrow uninitialized or moved variable `{}`",
284                        name
285                    ));
286                }
287                if local.0 < state.borrows.len() {
288                    let borrow = &mut state.borrows[local.0];
289                    if borrow.1 {
290                        let name = self.local_name(*local);
291                        self.errors.push(format!(
292                            "cannot borrow `{}` as immutable because it is also borrowed as mutable",
293                            name
294                        ));
295                    }
296                    borrow.0 += 1;
297                }
298            },
299            Rvalue::MutRef(local) => {
300                let s = state.get(*local);
301                if s != LocalState::Initialized {
302                    let name = self.local_name(*local);
303                    self.errors.push(format!(
304                        "cannot mutably borrow uninitialized or moved variable `{}`",
305                        name
306                    ));
307                }
308                if let Some(decl) = self.func.local_decl(*local) {
309                    if !decl.is_mut {
310                        let name = self.local_name(*local);
311                        self.errors.push(format!(
312                            "cannot mutably borrow immutable variable `{}`",
313                            name
314                        ));
315                    }
316                }
317                if local.0 < state.borrows.len() {
318                    let borrow = &mut state.borrows[local.0];
319                    if borrow.0 > 0 || borrow.1 {
320                        let name = self.local_name(*local);
321                        self.errors.push(format!(
322                            "cannot borrow `{}` as mutable because it is already borrowed",
323                            name
324                        ));
325                    }
326                    borrow.1 = true;
327                }
328            },
329            Rvalue::VectorSplat(op, _) => self.check_operand(op, state, span),
330            Rvalue::VectorLoad(obj, idx, _) => {
331                self.check_operand(obj, state, span);
332                self.check_operand(idx, state, span);
333            },
334            Rvalue::VectorFMA(a, b, c) => {
335                self.check_operand(a, state, span);
336                self.check_operand(b, state, span);
337                self.check_operand(c, state, span);
338            },
339        }
340    }
341
342    fn check_mutation(&mut self, op: &Operand, state: &FlowState, _span: Span) {
343        if let Operand::Copy(local) | Operand::Move(local) = op {
344            if local.0 < state.borrows.len() {
345                let borrow = state.borrows[local.0];
346                if borrow.0 > 0 || borrow.1 {
347                    let name = self.local_name(*local);
348                    self.errors
349                        .push(format!("cannot mutate `{}` because it is borrowed", name));
350                }
351            }
352        }
353    }
354
355    fn check_operand(&mut self, op: &Operand, state: &mut FlowState, _span: Span) {
356        match op {
357            Operand::Copy(local) | Operand::Move(local) => match state.get(*local) {
358                LocalState::Dead => {
359                    let is_unnamed = self
360                        .func
361                        .local_decl(*local)
362                        .map_or(true, |d| d.name.is_none());
363                    if is_unnamed {
364                        return;
365                    }
366                    let name = self.local_name(*local);
367                    self.errors
368                        .push(format!("use of possibly uninitialized variable `{}`", name));
369                },
370                LocalState::Moved => {
371                    let name = self.local_name(*local);
372                    self.errors
373                        .push(format!("use of moved variable `{}`", name));
374                },
375                LocalState::Initialized => {
376                    if let Operand::Move(local) = op {
377                        if self
378                            .func
379                            .local_decl(*local)
380                            .is_some_and(|d| d.ty.is_move_type())
381                        {
382                            if let Err(msg) = state.set(*local, LocalState::Moved) {
383                                let name = self.local_name(*local);
384                                self.errors.push(format!("{} `{}`", msg, name));
385                            }
386                        }
387                    }
388                },
389            },
390            Operand::Constant(_) => {},
391        }
392    }
393
394    fn release_dead_borrows(&self, state: &mut FlowState, live_locals: &HashSet<Local>) {
395        let mut still_borrowed = HashSet::new();
396        for (ref_var, pointed_var) in &self.provenance {
397            if live_locals.contains(ref_var) {
398                still_borrowed.insert(*pointed_var);
399            }
400        }
401
402        for (ref_var, pointed_var) in &self.provenance {
403            if !live_locals.contains(ref_var) && !still_borrowed.contains(pointed_var) {
404                if pointed_var.0 < state.borrows.len() {
405                    let borrow = &mut state.borrows[pointed_var.0];
406                    if borrow.0 > 0 {
407                        borrow.0 -= 1;
408                    } else {
409                        borrow.1 = false;
410                    }
411                }
412            }
413        }
414    }
415
416    fn local_name(&self, local: Local) -> String {
417        self.func
418            .local_decl(local)
419            .and_then(|decl| decl.name.as_ref())
420            .cloned()
421            .unwrap_or_else(|| format!("_{}", local.0))
422    }
423}