use std::collections::HashMap;
use crate::common::span::Spanned;
use crate::compiler::{
cst::{CST, CSTPattern},
sst::{SST, SSTPattern, UniqueSymbol, Scope},
syntax::Syntax,
};
/// Resolves every symbol in `cst` to a [`UniqueSymbol`], producing the
/// equivalent SST together with the root (top-level) scope.
///
/// # Errors
/// Propagates any error returned while walking the tree. Returns a `Syntax`
/// error spanning the whole tree if any name was referenced but never
/// assigned; the offending names are listed in sorted order.
pub fn hoist(cst: Spanned<CST>) -> Result<(Spanned<SST>, Scope), Syntax> {
    let mut hoister = Hoister::new();
    let sst = hoister.walk(cst)?;
    // The root scope is created by `Hoister::new` and is never popped by
    // `exit_scope`, so it is always present once the walk has finished.
    let scope = hoister.scopes.pop().expect("hoister always keeps a root scope");

    if !hoister.unresolved_hoists.is_empty() {
        // HashMap iteration order is randomized; sort so the diagnostic is
        // deterministic across runs.
        let mut names: Vec<&String> = hoister.unresolved_hoists.keys().collect();
        names.sort();
        let names = names
            .iter()
            .map(|s| format!("'{}'", s))
            .collect::<Vec<String>>()
            .join(", ");
        return Err(Syntax::error(
            &format!("{} were referenced before assignment", names),
            &sst.span,
        ));
    }

    Ok((sst, scope))
}
/// Walks a CST, replacing names with [`UniqueSymbol`]s and recording, for
/// each lambda scope, which symbols are locals and which are captured
/// nonlocals.
pub struct Hoister {
    /// Stack of scopes being walked. Index 0 is the root scope, and the last
    /// entry is the innermost (current) scope.
    scopes: Vec<Scope>,
    /// Source name of each `UniqueSymbol`, indexed by the symbol's `.0`.
    symbol_table: Vec<String>,
    /// Names referenced before any assignment to them has been seen, each
    /// mapped to the symbol minted for that first reference.
    unresolved_hoists: HashMap<String, UniqueSymbol>,
}
impl Hoister {
pub fn new() -> Hoister {
Hoister {
scopes: vec![Scope::new()],
symbol_table: vec![],
unresolved_hoists: HashMap::new(),
}
}
fn enter_scope(&mut self) { self.scopes.push(Scope::new()); }
fn reenter_scope(&mut self, scope: Scope) { self.scopes.push(scope) }
fn exit_scope(&mut self) -> Option<Scope> {
if self.scopes.len() == 1 { return None; }
return self.scopes.pop()
}
fn local_scope(&mut self) -> &mut Scope {
let last = self.scopes.len() - 1;
&mut self.scopes[last]
}
fn borrow_local_scope(&self) -> &Scope {
let last = self.scopes.len() - 1;
&self.scopes[last]
}
pub fn walk(&mut self, cst: Spanned<CST>) -> Result<Spanned<SST>, Syntax> {
let sst: SST = match cst.item {
CST::Data(data) => SST::Data(data),
CST::Symbol(name) => self.symbol(&name),
CST::Block(block) => self.block(block)?,
CST::Label(name, expression) => SST::Label(name, Box::new(self.walk(*expression)?)),
CST::Tuple(tuple) => self.tuple(tuple)?,
CST::FFI { name, expression } => SST::ffi(&name, self.walk(*expression)?),
CST::Assign { pattern, expression } => self.assign(*pattern, *expression)?,
CST::Lambda { pattern, expression } => self.lambda(*pattern, *expression)?,
CST::Call { fun, arg } => self.call(*fun, *arg)?,
};
return Ok(Spanned::new(sst, cst.span))
}
pub fn walk_pattern(&mut self, pattern: Spanned<CSTPattern>, declare: bool) -> Spanned<SSTPattern> {
let item = match pattern.item {
CSTPattern::Symbol(name) => {
SSTPattern::Symbol(self.resolve_assign(&name, declare))
},
CSTPattern::Data(d) => SSTPattern::Data(d),
CSTPattern::Label(n, p) => SSTPattern::Label(n, Box::new(self.walk_pattern(*p, declare))),
CSTPattern::Tuple(t) => SSTPattern::Tuple(
t.into_iter().map(|c| self.walk_pattern(c, declare)).collect::<Vec<_>>()
)
};
return Spanned::new(item, pattern.span);
}
fn new_symbol(&mut self, name: &str) -> UniqueSymbol {
let index = self.symbol_table.len();
self.symbol_table.push(name.to_string());
return UniqueSymbol(index);
}
fn local_symbol(&self, name: &str) -> Option<UniqueSymbol> {
for local in self.borrow_local_scope().locals.iter() {
let local_name = &self.symbol_table[local.0];
if local_name == name { return Some(*local); }
}
return None;
}
fn nonlocal_symbol(&self, name: &str) -> Option<UniqueSymbol> {
for nonlocal in self.borrow_local_scope().nonlocals.iter() {
let nonlocal_name = &self.symbol_table[nonlocal.0];
if nonlocal_name == name { return Some(*nonlocal); }
}
return None;
}
fn capture_all(&mut self, unique_symbol: UniqueSymbol) {
for scope in self.scopes.iter_mut() {
scope.nonlocals.push(unique_symbol);
}
}
fn uncapture_all(&mut self, unique_symbol: UniqueSymbol) {
for scope in self.scopes.iter_mut() {
let index = scope.nonlocal_index(unique_symbol).unwrap();
scope.nonlocals.remove(index);
}
}
fn try_resolve(&mut self, name: &str) -> Option<UniqueSymbol> {
if let Some(unique_symbol) = self.local_symbol(name) { return Some(unique_symbol); }
if let Some(unique_symbol) = self.nonlocal_symbol(name) { return Some(unique_symbol); }
if let Some(scope) = self.exit_scope() {
let resolved = self.try_resolve(name);
self.reenter_scope(scope);
if let Some(unique_symbol) = resolved {
self.local_scope().nonlocals.push(unique_symbol);
return Some(unique_symbol);
}
}
return None;
}
fn resolve_assign(&mut self, name: &str, redeclare: bool) -> UniqueSymbol {
if let Some(unique_symbol) = self.unresolved_hoists.get(name) {
let unique_symbol = *unique_symbol;
self.uncapture_all(unique_symbol);
self.unresolved_hoists.remove(name);
self.local_scope().locals.push(unique_symbol);
return unique_symbol;
}
if !redeclare {
if let Some(unique_symbol) = self.try_resolve(name) { return unique_symbol; }
}
let unique_symbol = self.new_symbol(name);
self.local_scope().locals.push(unique_symbol);
return unique_symbol;
}
fn resolve_symbol(&mut self, name: &str) -> UniqueSymbol {
if let Some(unique_symbol) = self.unresolved_hoists.get(name) {
return *unique_symbol;
}
if let Some(unique_symbol) = self.try_resolve(name) { return unique_symbol; }
let unique_symbol = self.new_symbol(name);
self.capture_all(unique_symbol);
self.unresolved_hoists.insert(name.to_string(), unique_symbol);
return unique_symbol;
}
pub fn symbol(&mut self, name: &str) -> SST {
return SST::Symbol(self.resolve_symbol(name));
}
pub fn block(&mut self, block: Vec<Spanned<CST>>) -> Result<SST, Syntax> {
let mut expressions = vec![];
for expression in block {
expressions.push(self.walk(expression)?)
}
Ok(SST::Block(expressions))
}
pub fn tuple(&mut self, tuple: Vec<Spanned<CST>>) -> Result<SST, Syntax> {
let mut expressions = vec![];
for expression in tuple {
expressions.push(self.walk(expression)?)
}
Ok(SST::Tuple(expressions))
}
pub fn assign(&mut self, pattern: Spanned<CSTPattern>, expression: Spanned<CST>) -> Result<SST, Syntax> {
let sst_pattern = self.walk_pattern(pattern, false);
let sst_expression = self.walk(expression)?;
return Ok(SST::assign(
sst_pattern,
sst_expression,
));
}
pub fn lambda(&mut self, pattern: Spanned<CSTPattern>, expression: Spanned<CST>) -> Result<SST, Syntax> {
self.enter_scope();
let sst_pattern = self.walk_pattern(pattern, true);
let sst_expression = self.walk(expression)?;
let scope = self.exit_scope().unwrap();
return Ok(SST::lambda(
sst_pattern,
sst_expression,
scope,
));
}
pub fn call(&mut self, fun: Spanned<CST>, arg: Spanned<CST>) -> Result<SST, Syntax> {
return Ok(SST::call(
self.walk(fun)?,
self.walk(arg)?,
));
}
}