use std::collections::HashMap;
use crate::ast::stmt::{Stmt, Expr, Block};
use crate::intern::{Interner, Symbol};
use crate::token::Span;
/// A zone-escape violation detected by `EscapeChecker`.
///
/// The human-readable explanation is produced by the `Display` impl.
#[derive(Debug, Clone)]
pub struct EscapeError {
    /// Which kind of escape was found, together with the resolved names involved.
    pub kind: EscapeErrorKind,
    /// Location of the violation.
    /// NOTE(review): `EscapeChecker` currently always fills this with `Span::default()`.
    pub span: Span,
}
/// The specific way a zone-allocated variable tried to outlive its zone.
#[derive(Debug, Clone)]
pub enum EscapeErrorKind {
    /// A zone-local variable is reachable from a returned value.
    ReturnEscape {
        /// Resolved name of the offending variable.
        variable: String,
        /// Resolved name of the zone the variable was bound in, or `"unknown"`.
        zone_name: String,
    },
    /// A zone-local variable is stored via `Set` into a variable bound at a
    /// shallower zone depth.
    AssignmentEscape {
        /// Resolved name of the offending variable.
        variable: String,
        /// Resolved name of the variable being assigned to.
        target: String,
        /// Resolved name of the zone the variable was bound in, or `"unknown"`.
        zone_name: String,
    },
}
impl std::fmt::Display for EscapeError {
    /// Renders a multi-line diagnostic: a headline naming the variable and zone,
    /// a shared explanation of zone deallocation, a variant-specific consequence,
    /// and a closing tip.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Each variant writes its own headline and hands back the sentence that
        // explains why this particular escape is dangerous.
        let consequence = match &self.kind {
            EscapeErrorKind::ReturnEscape { variable, zone_name } => {
                write!(f, "Reference '{}' cannot escape zone '{}'.", variable, zone_name)?;
                "Returning them would create a dangling reference."
            }
            EscapeErrorKind::AssignmentEscape { variable, target, zone_name } => {
                write!(
                    f,
                    "Reference '{}' cannot escape zone '{}' via assignment to '{}'.",
                    variable, zone_name, target
                )?;
                "Assigning them to outer scope variables would create a dangling reference."
            }
        };

        f.write_str("\n\nVariables allocated inside a zone are deallocated when the zone ends.\n")?;
        f.write_str(consequence)?;
        f.write_str("\n\nTip: Copy the data if you need it outside the zone.")
    }
}
// Marker impl so `EscapeError` can be used wherever a `std::error::Error` is
// expected (e.g. boxed as `dyn Error`); its message comes from `Display`.
impl std::error::Error for EscapeError {}
/// Walks a statement list and reports variables bound inside a `Zone` that
/// escape it through a `Return` value or a `Set` into a shallower variable.
pub struct EscapeChecker<'a> {
    /// Zone nesting depth at which each variable was bound by `Let`
    /// (0 means outside every zone).
    zone_depth: HashMap<Symbol, usize>,
    /// Number of zones currently open while walking the program.
    current_depth: usize,
    /// Names of the currently open zones, outermost first; the zone at depth
    /// `d` is stored at index `d - 1`.
    zone_stack: Vec<Symbol>,
    /// Interner used to turn symbols back into names for error messages.
    interner: &'a Interner,
}
impl<'a> EscapeChecker<'a> {
    /// Creates a checker that starts outside every zone with no recorded bindings.
    pub fn new(interner: &'a Interner) -> Self {
        Self {
            zone_depth: HashMap::new(),
            current_depth: 0,
            zone_stack: Vec::new(),
            interner,
        }
    }

    /// Checks a top-level statement list for zone-allocated variables that escape.
    ///
    /// # Errors
    ///
    /// Returns the first [`EscapeError`] encountered, in statement order.
    pub fn check_program(&mut self, stmts: &[Stmt<'_>]) -> Result<(), EscapeError> {
        self.check_block(stmts)
    }

    /// Checks each statement of a block in order, stopping at the first error.
    fn check_block(&mut self, stmts: &[Stmt<'_>]) -> Result<(), EscapeError> {
        for stmt in stmts {
            self.check_stmt(stmt)?;
        }
        Ok(())
    }

    /// Updates zone/binding state for one statement and checks any `Return` or
    /// `Set` it contains, recursing into nested blocks.
    fn check_stmt(&mut self, stmt: &Stmt<'_>) -> Result<(), EscapeError> {
        match stmt {
            Stmt::Zone { name, body, .. } => {
                // Bindings made inside the zone die with it. Snapshot the outer
                // bindings so a `Let` that shadows an outer variable does not leave
                // that variable looking zone-local once the zone closes.
                let outer_bindings = self.zone_depth.clone();
                self.current_depth += 1;
                self.zone_stack.push(*name);
                let result = self.check_block(body);
                // Unwind even on error so the checker is left in a consistent state.
                self.zone_stack.pop();
                self.current_depth -= 1;
                self.zone_depth = outer_bindings;
                result?;
            }
            Stmt::Let { var, .. } => {
                self.zone_depth.insert(*var, self.current_depth);
            }
            Stmt::Return { value: Some(expr) } => {
                // A returned value leaves every zone, so nothing bound deeper than
                // depth 0 may be reachable from it.
                self.check_no_escape(expr, 0)?;
            }
            Stmt::Set { target, value } => {
                // Targets with no recorded binding are treated as living outside every zone.
                let target_depth = self.zone_depth.get(target).copied().unwrap_or(0);
                self.check_no_escape_with_target(value, target_depth, *target)?;
            }
            Stmt::If { then_block, else_block, .. } => {
                self.check_block(then_block)?;
                if let Some(else_b) = else_block {
                    self.check_block(else_b)?;
                }
            }
            Stmt::While { body, .. } => {
                self.check_block(body)?;
            }
            Stmt::Repeat { body, .. } => {
                self.check_block(body)?;
            }
            Stmt::Inspect { arms, .. } => {
                for arm in arms {
                    self.check_block(arm.body)?;
                }
            }
            // `Escape` statements are deliberately not inspected.
            Stmt::Escape { .. } => {}
            _ => {}
        }
        Ok(())
    }

    /// Fails with [`EscapeErrorKind::ReturnEscape`] if `expr` mentions a variable
    /// bound deeper than `max_depth`.
    fn check_no_escape(&self, expr: &Expr<'_>, max_depth: usize) -> Result<(), EscapeError> {
        self.find_escape(expr, max_depth).map_err(|(sym, depth)| EscapeError {
            kind: EscapeErrorKind::ReturnEscape {
                variable: self.interner.resolve(sym).to_string(),
                zone_name: self.zone_name_at(depth),
            },
            span: Span::default(),
        })
    }

    /// Fails with [`EscapeErrorKind::AssignmentEscape`] if `expr` (the value being
    /// stored into `target`) mentions a variable bound deeper than `max_depth`.
    ///
    /// Nested occurrences (e.g. inside a list or field access) are reported as
    /// assignment escapes too, not only bare identifiers.
    fn check_no_escape_with_target(
        &self,
        expr: &Expr<'_>,
        max_depth: usize,
        target: Symbol,
    ) -> Result<(), EscapeError> {
        self.find_escape(expr, max_depth).map_err(|(sym, depth)| EscapeError {
            kind: EscapeErrorKind::AssignmentEscape {
                variable: self.interner.resolve(sym).to_string(),
                target: self.interner.resolve(target).to_string(),
                zone_name: self.zone_name_at(depth),
            },
            span: Span::default(),
        })
    }

    /// Resolves the name of the open zone at nesting `depth` (1-based), or
    /// `"unknown"` if no such zone is open.
    fn zone_name_at(&self, depth: usize) -> String {
        depth
            .checked_sub(1)
            .and_then(|index| self.zone_stack.get(index))
            .map(|zone| self.interner.resolve(*zone).to_string())
            .unwrap_or_else(|| "unknown".to_string())
    }

    /// Walks `expr` and returns `Err((symbol, depth))` for the first identifier
    /// whose recorded binding depth exceeds `max_depth`.
    ///
    /// The walk is shared by the `Return` and `Set` checks; callers turn the
    /// offending symbol into the appropriate [`EscapeError`].
    fn find_escape(&self, expr: &Expr<'_>, max_depth: usize) -> Result<(), (Symbol, usize)> {
        match expr {
            Expr::Identifier(sym) => {
                if let Some(&depth) = self.zone_depth.get(sym) {
                    // `depth > max_depth` also guarantees `depth >= 1`.
                    if depth > max_depth {
                        return Err((*sym, depth));
                    }
                }
            }
            Expr::BinaryOp { left, right, .. } => {
                self.find_escape(left, max_depth)?;
                self.find_escape(right, max_depth)?;
            }
            Expr::Call { args, .. } => {
                for arg in args.iter() {
                    self.find_escape(arg, max_depth)?;
                }
            }
            Expr::FieldAccess { object, .. } => {
                self.find_escape(object, max_depth)?;
            }
            Expr::Index { collection, index } => {
                self.find_escape(collection, max_depth)?;
                self.find_escape(index, max_depth)?;
            }
            Expr::Slice { collection, start, end } => {
                self.find_escape(collection, max_depth)?;
                self.find_escape(start, max_depth)?;
                self.find_escape(end, max_depth)?;
            }
            Expr::Copy { expr } | Expr::Give { value: expr } | Expr::Length { collection: expr }
            | Expr::Not { operand: expr } => {
                self.find_escape(expr, max_depth)?;
            }
            Expr::List(items) | Expr::Tuple(items) => {
                for item in items.iter() {
                    self.find_escape(item, max_depth)?;
                }
            }
            Expr::Range { start, end } => {
                self.find_escape(start, max_depth)?;
                self.find_escape(end, max_depth)?;
            }
            Expr::New { init_fields, .. } => {
                for (_, expr) in init_fields.iter() {
                    self.find_escape(expr, max_depth)?;
                }
            }
            Expr::NewVariant { fields, .. } => {
                for (_, expr) in fields.iter() {
                    self.find_escape(expr, max_depth)?;
                }
            }
            Expr::ManifestOf { zone } => {
                self.find_escape(zone, max_depth)?;
            }
            Expr::ChunkAt { index, zone } => {
                self.find_escape(index, max_depth)?;
                self.find_escape(zone, max_depth)?;
            }
            Expr::Contains { collection, value } => {
                self.find_escape(collection, max_depth)?;
                self.find_escape(value, max_depth)?;
            }
            Expr::Union { left, right } | Expr::Intersection { left, right } => {
                self.find_escape(left, max_depth)?;
                self.find_escape(right, max_depth)?;
            }
            Expr::WithCapacity { value, capacity } => {
                self.find_escape(value, max_depth)?;
                self.find_escape(capacity, max_depth)?;
            }
            Expr::OptionSome { value } => {
                self.find_escape(value, max_depth)?;
            }
            Expr::OptionNone => {}
            // `Escape` expressions are deliberately not inspected.
            Expr::Escape { .. } => {}
            Expr::Closure { body, .. } => match body {
                crate::ast::stmt::ClosureBody::Expression(expr) => {
                    self.find_escape(expr, max_depth)?;
                }
                // NOTE(review): block-bodied closures are not inspected, so a closure
                // capturing a zone-local variable in a block body is not reported.
                crate::ast::stmt::ClosureBody::Block(_) => {}
            },
            Expr::CallExpr { callee, args } => {
                self.find_escape(callee, max_depth)?;
                for arg in args.iter() {
                    self.find_escape(arg, max_depth)?;
                }
            }
            Expr::InterpolatedString(parts) => {
                for part in parts.iter() {
                    if let crate::ast::stmt::StringPart::Expr { value, .. } = part {
                        self.find_escape(value, max_depth)?;
                    }
                }
            }
            Expr::Literal(_) => {}
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed checker starts outside every zone with no bindings.
    #[test]
    fn test_escape_checker_basic() {
        // `Interner` comes in through `use super::*`; it is never mutated here,
        // so no `mut` binding is needed.
        let interner = Interner::new();
        let checker = EscapeChecker::new(&interner);
        assert_eq!(checker.current_depth, 0);
        assert!(checker.zone_stack.is_empty());
        assert!(checker.zone_depth.is_empty());
    }
}