use crate::ast::{ExprType, VarRef};
use crate::bytecode::{Register, RegisterScope};
use crate::compiler::ids::HashMapWithIds;
use crate::{CallableMetadata, bytecode};
use std::cell::Cell;
use std::cell::RefCell;
use std::cmp::max;
use std::collections::HashMap;
use std::convert::TryFrom;
use std::fmt;
use std::rc::Rc;
/// Type information recorded for an array symbol.
#[derive(Clone, Copy, Debug, PartialEq)]
pub(super) struct ArrayInfo {
    /// Element type of the array; references to the array are type-checked against it.
    pub(super) subtype: ExprType,
    /// Number of dimensions of the array.
    pub(super) ndims: usize,
}

/// What a symbol table stores for each symbol: either an array or a scalar.
#[derive(Clone, Copy, Debug, PartialEq)]
pub(super) enum SymbolPrototype {
    /// An array symbol, described by its element type and dimensionality.
    Array(ArrayInfo),
    /// A scalar symbol of the given type.
    Scalar(ExprType),
}
#[derive(Debug, thiserror::Error)]
#[allow(missing_docs)] pub(super) enum Error {
#[error("Cannot redefine {0}")]
AlreadyDefined(VarRef),
#[error("Incompatible type annotation in {0} reference")]
IncompatibleTypeAnnotationInReference(VarRef),
#[error("Out of {0} registers")]
OutOfRegisters(RegisterScope),
#[error("Undefined {1} symbol {0}")]
UndefinedSymbol(VarRef, RegisterScope),
}
type Result<T> = std::result::Result<T, Error>;
/// Case-insensitive key used to index symbols.
///
/// Construction normalizes the name to ASCII uppercase, so `foo`, `Foo` and `FOO` all produce
/// the same key. Non-ASCII characters are kept as-is.
#[derive(Clone, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
pub struct SymbolKey(String);

impl<R: AsRef<str>> From<R> for SymbolKey {
    fn from(value: R) -> Self {
        let normalized = value.as_ref().to_ascii_uppercase();
        SymbolKey(normalized)
    }
}

impl fmt::Display for SymbolKey {
    /// Writes the normalized (uppercase) name; formatter width/fill options are not applied.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str(&self.0)
    }
}
fn get_var<MKR>(
vref: &VarRef,
table: &HashMapWithIds<SymbolKey, SymbolPrototype, u8>,
make_register: MKR,
scope: RegisterScope,
) -> Result<(Register, SymbolPrototype)>
where
MKR: FnOnce(u8) -> std::result::Result<Register, bytecode::OutOfRegistersError>,
{
let key = SymbolKey::from(&vref.name);
match table.get(&key) {
Some((SymbolPrototype::Array(info), reg)) => {
if !vref.accepts(info.subtype) {
return Err(Error::IncompatibleTypeAnnotationInReference(vref.clone()));
}
let reg = make_register(reg).map_err(|_| Error::OutOfRegisters(scope))?;
Ok((reg, SymbolPrototype::Array(*info)))
}
Some((SymbolPrototype::Scalar(etype), reg)) => {
if !vref.accepts(*etype) {
return Err(Error::IncompatibleTypeAnnotationInReference(vref.clone()));
}
let reg = make_register(reg).map_err(|_| Error::OutOfRegisters(scope))?;
Ok((reg, SymbolPrototype::Scalar(*etype)))
}
None => Err(Error::UndefinedSymbol(vref.clone(), scope)),
}
}
fn put_var<MKR>(
key: SymbolKey,
proto: SymbolPrototype,
table: &mut HashMapWithIds<SymbolKey, SymbolPrototype, u8>,
make_register: MKR,
scope: RegisterScope,
) -> Result<Register>
where
MKR: FnOnce(u8) -> std::result::Result<Register, bytecode::OutOfRegistersError>,
{
match table.insert(key, proto) {
Some((None, reg)) => Ok(make_register(reg).map_err(|_| Error::OutOfRegisters(scope))?),
Some((Some(_old_proto), _reg)) => {
unreachable!("Cannot redefine symbol; caller must check for presence first");
}
None => Err(Error::OutOfRegisters(scope)),
}
}
/// Program-wide symbol table holding global variables and callable definitions.
#[derive(Clone)]
pub(crate) struct GlobalSymtable {
    /// Global variables, each mapped to the slot index of its global register.
    globals: HashMapWithIds<SymbolKey, SymbolPrototype, u8>,
    /// Callables supplied to `GlobalSymtable::new`.
    upcalls: HashMap<SymbolKey, Rc<CallableMetadata>>,
    /// Callables registered through `declare_user_callable`; these shadow `upcalls`.
    user_callables: HashMap<SymbolKey, Rc<CallableMetadata>>,
}
impl GlobalSymtable {
    /// Creates an empty global table whose callable lookups fall back to `upcalls`.
    pub(crate) fn new(upcalls: HashMap<SymbolKey, Rc<CallableMetadata>>) -> Self {
        Self { globals: HashMapWithIds::default(), upcalls, user_callables: HashMap::default() }
    }

    /// Opens a new, empty local scope layered on top of this table.
    pub(crate) fn enter_scope(&mut self) -> LocalSymtable<'_> {
        LocalSymtable::new(self)
    }

    /// Looks up the global variable referenced by `vref`.
    ///
    /// # Errors
    ///
    /// Fails if the global is undefined, if `vref`'s type annotation is incompatible with it,
    /// or if its slot index cannot be turned into a global register.
    pub(crate) fn get_global(&self, vref: &VarRef) -> Result<(Register, SymbolPrototype)> {
        get_var(vref, &self.globals, Register::global, RegisterScope::Global)
    }

    /// Defines the new global `key` with prototype `proto` and returns its register.
    ///
    /// # Errors
    ///
    /// Returns `Error::OutOfRegisters` if no global register is available.
    ///
    /// # Panics
    ///
    /// Panics if `key` is already a global; check with `contains_global` first.
    pub(crate) fn put_global(
        &mut self,
        key: SymbolKey,
        proto: SymbolPrototype,
    ) -> Result<Register> {
        put_var(key, proto, &mut self.globals, Register::global, RegisterScope::Global)
    }

    /// Returns whether a global variable named `key` exists.
    pub(crate) fn contains_global(&self, key: &SymbolKey) -> bool {
        self.globals.get(key).is_some()
    }

    /// Iterates over all globals as `(key, prototype, slot index)` tuples.
    pub(crate) fn iter_globals(
        &self,
    ) -> impl Iterator<Item = (&SymbolKey, SymbolPrototype, u8)> + '_ {
        self.globals.iter().map(|(k, v, i)| (k, *v, i))
    }

    /// Registers a user-defined callable named after `vref` with metadata `md`.
    ///
    /// Redeclaring a callable with metadata equal to the previous declaration is accepted.
    ///
    /// # Errors
    ///
    /// Returns `Error::AlreadyDefined` if a global variable has the same name, or if a callable
    /// with the same name was already declared with different metadata. On error the table is
    /// left unchanged.
    pub(crate) fn declare_user_callable(
        &mut self,
        vref: &VarRef,
        md: Rc<CallableMetadata>,
    ) -> Result<()> {
        let key = SymbolKey::from(&vref.name);
        if self.globals.get(&key).is_some() {
            return Err(Error::AlreadyDefined(vref.clone()));
        }
        // Validate against the previous declaration *before* inserting, so that a rejected
        // redeclaration does not overwrite the existing, valid definition.
        if let Some(previous_md) = self.user_callables.get(&key)
            && *previous_md != md
        {
            return Err(Error::AlreadyDefined(vref.clone()));
        }
        self.user_callables.insert(key, md);
        Ok(())
    }

    /// Looks up a callable by name; user callables take precedence over upcalls.
    pub(crate) fn get_callable(&self, key: &SymbolKey) -> Option<Rc<CallableMetadata>> {
        // `or_else` skips the upcall lookup entirely when a user callable matches.
        self.user_callables.get(key).or_else(|| self.upcalls.get(key)).cloned()
    }
}
/// Saved state of a `LocalSymtable`, detached from the global table it was layered on.
///
/// Produced by `LocalSymtable::save` and consumed by `LocalSymtable::restore`.
#[derive(Clone, Default)]
pub(crate) struct LocalSymtableSnapshot {
    /// Local variables of the saved scope.
    locals: HashMapWithIds<SymbolKey, SymbolPrototype, u8>,
    /// Peak number of temporaries used by the saved scope.
    count_temps: u8,
    /// Number of temporaries currently reserved via `with_reserved_temp`.
    active_temps: Rc<Cell<u8>>,
}
/// Symbol table for a local scope, layered on top of a `GlobalSymtable`.
///
/// Name lookups that fail locally fall back to the global table.
pub(crate) struct LocalSymtable<'a> {
    /// Global table used for non-local names and callables.
    symtable: &'a mut GlobalSymtable,
    /// Local variables, each mapped to the slot index of its local register.
    locals: HashMapWithIds<SymbolKey, SymbolPrototype, u8>,
    /// Peak number of temporaries used so far in this scope.
    count_temps: u8,
    /// Number of temporaries currently reserved via `with_reserved_temp`; shared with the
    /// guards that release those reservations.
    active_temps: Rc<Cell<u8>>,
}
impl<'a> LocalSymtable<'a> {
fn new(symtable: &'a mut GlobalSymtable) -> Self {
Self {
symtable,
locals: HashMapWithIds::default(),
count_temps: 0,
active_temps: Rc::from(Cell::new(0)),
}
}
pub(crate) fn save(self) -> LocalSymtableSnapshot {
LocalSymtableSnapshot {
locals: self.locals,
count_temps: self.count_temps,
active_temps: self.active_temps,
}
}
pub(crate) fn restore(
symtable: &'a mut GlobalSymtable,
snapshot: LocalSymtableSnapshot,
) -> Self {
Self {
symtable,
locals: snapshot.locals,
count_temps: snapshot.count_temps,
active_temps: snapshot.active_temps,
}
}
pub(crate) fn global(&mut self) -> &mut GlobalSymtable {
self.symtable
}
pub(crate) fn declare_user_callable(
&mut self,
vref: &VarRef,
md: Rc<CallableMetadata>,
) -> Result<()> {
self.symtable.declare_user_callable(vref, md)
}
pub(crate) fn frozen(&mut self) -> TempSymtable<'_, 'a> {
TempSymtable::new(self)
}
pub(crate) fn with_reserved_temp<T, E, ME, F>(
&mut self,
map_error: ME,
f: F,
) -> std::result::Result<T, E>
where
ME: Fn(Error) -> E,
F: FnOnce(Register, &mut TempSymtable<'_, 'a>) -> std::result::Result<T, E>,
{
struct TempReservationGuard {
active_temps: Rc<Cell<u8>>,
}
impl Drop for TempReservationGuard {
fn drop(&mut self) {
let active_temps = self.active_temps.get();
debug_assert!(active_temps > 0);
self.active_temps.set(active_temps - 1);
}
}
let nlocals = u8::try_from(self.locals.len())
.map_err(|_| map_error(Error::OutOfRegisters(RegisterScope::Local)))?;
let first_temp = self.active_temps.get();
let new_active_temps = first_temp
.checked_add(1)
.ok_or(map_error(Error::OutOfRegisters(RegisterScope::Temp)))?;
self.active_temps.set(new_active_temps);
self.count_temps = max(self.count_temps, new_active_temps);
let _guard = TempReservationGuard { active_temps: self.active_temps.clone() };
let reg_idx = u8::try_from(usize::from(nlocals) + usize::from(first_temp))
.map_err(|_| map_error(Error::OutOfRegisters(RegisterScope::Temp)))?;
let reg = Register::local(reg_idx)
.map_err(|_| map_error(Error::OutOfRegisters(RegisterScope::Temp)))?;
let mut temp = self.frozen();
f(reg, &mut temp)
}
pub(crate) fn put_global(
&mut self,
key: SymbolKey,
proto: SymbolPrototype,
) -> Result<Register> {
self.symtable.put_global(key, proto)
}
pub(crate) fn get_local_or_global(&self, vref: &VarRef) -> Result<(Register, SymbolPrototype)> {
match get_var(vref, &self.locals, Register::local, RegisterScope::Local) {
Ok(local) => Ok(local),
Err(Error::UndefinedSymbol(..)) => self.symtable.get_global(vref),
Err(e) => Err(e),
}
}
pub(crate) fn get_callable(&self, key: &SymbolKey) -> Option<Rc<CallableMetadata>> {
self.symtable.get_callable(key)
}
pub(crate) fn put_local(&mut self, key: SymbolKey, proto: SymbolPrototype) -> Result<Register> {
put_var(key, proto, &mut self.locals, Register::local, RegisterScope::Local)
}
pub(crate) fn contains_local(&self, key: &SymbolKey) -> bool {
self.locals.get(key).is_some()
}
pub(crate) fn contains_global(&self, key: &SymbolKey) -> bool {
self.symtable.contains_global(key)
}
pub(crate) fn iter_globals(
&self,
) -> impl Iterator<Item = (SymbolKey, SymbolPrototype, u8)> + '_ {
self.symtable.iter_globals().map(|(k, v, i)| (k.clone(), v, i))
}
pub(crate) fn iter_locals(
&self,
) -> impl Iterator<Item = (SymbolKey, SymbolPrototype, u8)> + '_ {
self.locals.iter().map(|(k, v, i)| (k.clone(), *v, i))
}
pub(crate) fn fixup_local_type(&mut self, vref: &VarRef, new_etype: ExprType) -> Result<()> {
let key = SymbolKey::from(&vref.name);
match self.locals.get_mut(&key) {
Some((SymbolPrototype::Array(_), _)) | None => {
Err(Error::UndefinedSymbol(vref.clone(), RegisterScope::Local))
}
Some((SymbolPrototype::Scalar(etype), _reg)) => {
*etype = new_etype;
Ok(())
}
}
}
}
/// Frozen view of a `LocalSymtable` used to allocate temporary registers.
///
/// Holding this view mutably borrows the local table, so no locals can be added while
/// temporaries are being handed out. Temporaries are numbered after the locals.
pub(crate) struct TempSymtable<'temp, 'local> {
    /// Local table whose locals precede the temporaries.
    symtable: &'temp mut LocalSymtable<'local>,
    /// Temp cursor value at creation; the cursor must be back here when the view is dropped.
    base_temp: u8,
    /// Shared cursor pointing at the next free temp index.
    next_temp: Rc<RefCell<u8>>,
    /// Shared high-water mark of temp usage, folded into the local table on drop.
    count_temps: Rc<RefCell<u8>>,
}
impl<'temp, 'local> Drop for TempSymtable<'temp, 'local> {
fn drop(&mut self) {
debug_assert_eq!(self.base_temp, *self.next_temp.borrow(), "Unbalanced temp drops");
self.symtable.count_temps = max(self.symtable.count_temps, *self.count_temps.borrow());
}
}
impl<'temp, 'local> TempSymtable<'temp, 'local> {
    /// Creates a view whose temporaries start after those reserved via `with_reserved_temp`.
    fn new(symtable: &'temp mut LocalSymtable<'local>) -> Self {
        let base_temp = symtable.active_temps.get();
        let next_temp = Rc::new(RefCell::new(base_temp));
        let count_temps = Rc::new(RefCell::new(base_temp));
        Self { symtable, base_temp, next_temp, count_temps }
    }

    /// Looks up a variable through the underlying local table (locals first, then globals).
    pub(crate) fn get_local_or_global(&self, vref: &VarRef) -> Result<(Register, SymbolPrototype)> {
        self.symtable.get_local_or_global(vref)
    }

    /// Looks up a callable through the underlying local table.
    pub(crate) fn get_callable(&self, key: &SymbolKey) -> Option<Rc<CallableMetadata>> {
        self.symtable.get_callable(key)
    }

    /// Opens a nested temp scope starting at the current temp cursor.
    ///
    /// # Panics
    ///
    /// Panics if the local table holds more locals than fit in a `u8`.
    pub(crate) fn temp_scope(&self) -> TempScope {
        let locals_len = self.symtable.locals.len();
        let nlocals = u8::try_from(locals_len).expect("Cannot have allocated more locals than u8");
        let base_temp = *self.next_temp.borrow();
        TempScope {
            base_temp,
            nlocals,
            ntemps: 0,
            next_temp: Rc::clone(&self.next_temp),
            count_temps: Rc::clone(&self.count_temps),
        }
    }
}
/// A nested allocation region for temporary registers.
///
/// Temporaries allocated through `TempScope::alloc` are released when the scope is dropped.
pub(crate) struct TempScope {
    /// Temp cursor value when this scope was opened.
    base_temp: u8,
    /// Number of locals in the owning local table; temps are numbered after them.
    nlocals: u8,
    /// Number of temps allocated through this scope so far.
    ntemps: u8,
    /// Shared cursor pointing at the next free temp index.
    next_temp: Rc<RefCell<u8>>,
    /// Shared high-water mark of temp usage.
    count_temps: Rc<RefCell<u8>>,
}
impl Drop for TempScope {
    /// Releases every temp this scope allocated by rewinding the shared cursor.
    fn drop(&mut self) {
        let released = self.ntemps;
        let mut cursor = self.next_temp.borrow_mut();
        debug_assert!(*cursor >= released);
        *cursor -= released;
    }
}
impl TempScope {
pub(crate) fn first(&mut self) -> Result<Register> {
let reg = u8::try_from(usize::from(self.nlocals) + usize::from(self.base_temp))
.map_err(|_| Error::OutOfRegisters(RegisterScope::Temp))?;
Register::local(reg).map_err(|_| Error::OutOfRegisters(RegisterScope::Temp))
}
pub(crate) fn alloc(&mut self) -> Result<Register> {
let temp;
let new_next_temp;
{
let mut next_temp = self.next_temp.borrow_mut();
temp = *next_temp;
self.ntemps += 1;
new_next_temp = match next_temp.checked_add(1) {
Some(reg) => reg,
None => return Err(Error::OutOfRegisters(RegisterScope::Temp)),
};
*next_temp = new_next_temp;
}
{
let mut count_temps = self.count_temps.borrow_mut();
*count_temps = max(*count_temps, new_next_temp);
}
match u8::try_from(usize::from(self.nlocals) + usize::from(temp)) {
Ok(reg) => {
Ok(Register::local(reg).map_err(|_| Error::OutOfRegisters(RegisterScope::Temp))?)
}
Err(_) => Err(Error::OutOfRegisters(RegisterScope::Temp)),
}
}
}
// Unit tests for symbol keys, global/local symbol resolution, callable lookup, and
// temporary register allocation. Register numbers asserted below depend on insertion order.
#[cfg(test)]
mod tests {
use super::*;
use crate::CallableMetadataBuilder;
// Keys differing only in ASCII case compare equal.
#[test]
fn test_symbol_key_case_insensitive() {
assert_eq!(SymbolKey::from("foo"), SymbolKey::from("FOO"));
assert_eq!(SymbolKey::from("Foo"), SymbolKey::from("fOo"));
}
// Display shows the uppercase-normalized key.
#[test]
fn test_symbol_key_display() {
assert_eq!("FOO", format!("{}", SymbolKey::from("foo")));
}
// Globals get consecutive global registers in insertion order and can be read back.
#[test]
fn test_global_put_and_get() -> Result<()> {
let upcalls = HashMap::default();
let mut global = GlobalSymtable::new(upcalls);
let reg =
global.put_global(SymbolKey::from("x"), SymbolPrototype::Scalar(ExprType::Integer))?;
assert_eq!(Register::global(0).unwrap(), reg);
let reg =
global.put_global(SymbolKey::from("y"), SymbolPrototype::Scalar(ExprType::Text))?;
assert_eq!(Register::global(1).unwrap(), reg);
let (reg, proto) = global.get_global(&VarRef::new("x", None))?;
assert_eq!(Register::global(0).unwrap(), reg);
assert_eq!(SymbolPrototype::Scalar(ExprType::Integer), proto);
let (reg, proto) = global.get_global(&VarRef::new("y", Some(ExprType::Text)))?;
assert_eq!(Register::global(1).unwrap(), reg);
assert_eq!(SymbolPrototype::Scalar(ExprType::Text), proto);
Ok(())
}
// Global lookup ignores the case of the referenced name.
#[test]
fn test_global_get_case_insensitive() -> Result<()> {
let upcalls = HashMap::default();
let mut global = GlobalSymtable::new(upcalls);
global.put_global(SymbolKey::from("MyVar"), SymbolPrototype::Scalar(ExprType::Double))?;
let (reg, proto) = global.get_global(&VarRef::new("myvar", None))?;
assert_eq!(Register::global(0).unwrap(), reg);
assert_eq!(SymbolPrototype::Scalar(ExprType::Double), proto);
let (reg2, _) = global.get_global(&VarRef::new("MYVAR", None))?;
assert_eq!(reg, reg2);
Ok(())
}
// A type annotation that does not match the stored scalar type is rejected.
#[test]
fn test_global_get_incompatible_type() {
let upcalls = HashMap::default();
let mut global = GlobalSymtable::new(upcalls);
global
.put_global(SymbolKey::from("x"), SymbolPrototype::Scalar(ExprType::Integer))
.unwrap();
let err = global.get_global(&VarRef::new("x", Some(ExprType::Text))).unwrap_err();
assert_eq!("Incompatible type annotation in x$ reference", err.to_string());
}
// Looking up a missing global reports an undefined global symbol.
#[test]
fn test_global_get_undefined() {
let upcalls = HashMap::default();
let global = GlobalSymtable::new(upcalls);
let err = global.get_global(&VarRef::new("x", None)).unwrap_err();
assert_eq!("Undefined global symbol x", err.to_string());
}
// Locals get consecutive local registers starting at 0.
#[test]
fn test_local_put_and_get() -> Result<()> {
let upcalls = HashMap::default();
let mut global = GlobalSymtable::new(upcalls);
let mut local = global.enter_scope();
let reg =
local.put_local(SymbolKey::from("a"), SymbolPrototype::Scalar(ExprType::Boolean))?;
assert_eq!(Register::local(0).unwrap(), reg);
let reg =
local.put_local(SymbolKey::from("b"), SymbolPrototype::Scalar(ExprType::Double))?;
assert_eq!(Register::local(1).unwrap(), reg);
let (reg, proto) = local.get_local_or_global(&VarRef::new("a", None))?;
assert_eq!(Register::local(0).unwrap(), reg);
assert_eq!(SymbolPrototype::Scalar(ExprType::Boolean), proto);
Ok(())
}
// A local with the same name as a global takes precedence over it.
#[test]
fn test_local_shadows_global() -> Result<()> {
let upcalls = HashMap::default();
let mut global = GlobalSymtable::new(upcalls);
global.put_global(SymbolKey::from("x"), SymbolPrototype::Scalar(ExprType::Integer))?;
let mut local = global.enter_scope();
local.put_local(SymbolKey::from("x"), SymbolPrototype::Scalar(ExprType::Text))?;
let (reg, proto) = local.get_local_or_global(&VarRef::new("x", None))?;
assert_eq!(Register::local(0).unwrap(), reg);
assert_eq!(SymbolPrototype::Scalar(ExprType::Text), proto);
Ok(())
}
// Names not defined locally resolve to globals.
#[test]
fn test_local_falls_through_to_global() -> Result<()> {
let upcalls = HashMap::default();
let mut global = GlobalSymtable::new(upcalls);
global.put_global(SymbolKey::from("g"), SymbolPrototype::Scalar(ExprType::Integer))?;
let local = global.enter_scope();
let (reg, proto) = local.get_local_or_global(&VarRef::new("g", None))?;
assert_eq!(Register::global(0).unwrap(), reg);
assert_eq!(SymbolPrototype::Scalar(ExprType::Integer), proto);
Ok(())
}
// When neither scope defines the name, the global lookup's error is the one reported.
#[test]
fn test_local_get_undefined() {
let upcalls = HashMap::default();
let mut global = GlobalSymtable::new(upcalls);
let local = global.enter_scope();
let err = local.get_local_or_global(&VarRef::new("nope", None)).unwrap_err();
assert_eq!("Undefined global symbol nope", err.to_string());
}
// Globals can be defined through a local table and are then visible from it.
#[test]
fn test_local_put_global_through_local() -> Result<()> {
let upcalls = HashMap::default();
let mut global = GlobalSymtable::new(upcalls);
let mut local = global.enter_scope();
let reg =
local.put_global(SymbolKey::from("g"), SymbolPrototype::Scalar(ExprType::Integer))?;
assert_eq!(Register::global(0).unwrap(), reg);
let (reg, proto) = local.get_local_or_global(&VarRef::new("g", None))?;
assert_eq!(Register::global(0).unwrap(), reg);
assert_eq!(SymbolPrototype::Scalar(ExprType::Integer), proto);
Ok(())
}
// Retyping a scalar local changes the prototype returned by later lookups.
#[test]
fn test_fixup_local_type() -> Result<()> {
let upcalls = HashMap::default();
let mut global = GlobalSymtable::new(upcalls);
let mut local = global.enter_scope();
local.put_local(SymbolKey::from("x"), SymbolPrototype::Scalar(ExprType::Integer))?;
local.fixup_local_type(&VarRef::new("x", None), ExprType::Double)?;
let (_, proto) = local.get_local_or_global(&VarRef::new("x", None))?;
assert_eq!(SymbolPrototype::Scalar(ExprType::Double), proto);
Ok(())
}
// Retyping an unknown local fails with an undefined local symbol error.
#[test]
fn test_fixup_local_type_undefined() {
let upcalls = HashMap::default();
let mut global = GlobalSymtable::new(upcalls);
let mut local = global.enter_scope();
let err =
local.fixup_local_type(&VarRef::new("nope", None), ExprType::Integer).unwrap_err();
assert_eq!("Undefined local symbol nope", err.to_string());
}
// A declared user callable can be found with a differently-cased key.
#[test]
fn test_define_and_get_user_callable() -> Result<()> {
let upcalls = HashMap::default();
let mut global = GlobalSymtable::new(upcalls);
let md = CallableMetadataBuilder::new("MY_FUNC")
.with_return_type(ExprType::Integer)
.test_build();
global.declare_user_callable(&VarRef::new("my_func", None), md)?;
let found = global.get_callable(&SymbolKey::from("my_func"));
assert!(found.is_some());
assert_eq!("MY_FUNC", found.unwrap().name());
Ok(())
}
// Redeclaring a callable with equal metadata is accepted.
#[test]
fn test_define_user_callable_already_defined_but_is_compatible() -> Result<()> {
let upcalls = HashMap::default();
let mut global = GlobalSymtable::new(upcalls);
let md = CallableMetadataBuilder::new("DUP").test_build();
global.declare_user_callable(&VarRef::new("dup", None), md)?;
let md2 = CallableMetadataBuilder::new("DUP").test_build();
global.declare_user_callable(&VarRef::new("dup", None), md2)?;
Ok(())
}
// Redeclaring a callable with different metadata is rejected.
#[test]
fn test_define_user_callable_already_defined_but_is_incompatible() -> Result<()> {
let upcalls = HashMap::default();
let mut global = GlobalSymtable::new(upcalls);
let md = CallableMetadataBuilder::new("DUP").test_build();
global.declare_user_callable(&VarRef::new("dup", None), md)?;
let md2 =
CallableMetadataBuilder::new("DUP").with_return_type(ExprType::Integer).test_build();
let err = global.declare_user_callable(&VarRef::new("dup", None), md2).unwrap_err();
assert_eq!("Cannot redefine dup", err.to_string());
Ok(())
}
// Callables declared through a local table are visible through it.
#[test]
fn test_define_user_callable_via_local() -> Result<()> {
let upcalls = HashMap::default();
let mut global = GlobalSymtable::new(upcalls);
let mut local = global.enter_scope();
let md = CallableMetadataBuilder::new("SUB1").test_build();
local.declare_user_callable(&VarRef::new("sub1", None), md)?;
let found = local.get_callable(&SymbolKey::from("sub1"));
assert!(found.is_some());
Ok(())
}
// Upcalls supplied at construction are found with a differently-cased key.
#[test]
fn test_get_callable_upcall() {
let key = SymbolKey::from("BUILTIN");
let md = CallableMetadataBuilder::new("BUILTIN").test_build();
let mut upcalls_map = HashMap::new();
upcalls_map.insert(key, md);
let global = GlobalSymtable::new(upcalls_map);
let found = global.get_callable(&SymbolKey::from("builtin"));
assert!(found.is_some());
assert_eq!("BUILTIN", found.unwrap().name());
}
// A user callable takes precedence over an upcall with the same name.
#[test]
fn test_user_callable_shadows_upcall() {
let key = SymbolKey::from("SHARED");
let builtin_md =
CallableMetadataBuilder::new("SHARED").with_return_type(ExprType::Boolean).test_build();
let mut upcalls_map = HashMap::new();
upcalls_map.insert(key, builtin_md);
let mut global = GlobalSymtable::new(upcalls_map);
let user_md =
CallableMetadataBuilder::new("SHARED").with_return_type(ExprType::Integer).test_build();
global.declare_user_callable(&VarRef::new("shared", None), user_md).unwrap();
let found = global.get_callable(&SymbolKey::from("shared")).unwrap();
assert_eq!(Some(ExprType::Integer), found.return_type());
}
// Unknown callables yield None.
#[test]
fn test_get_callable_not_found() {
let upcalls = HashMap::default();
let global = GlobalSymtable::new(upcalls);
assert!(global.get_callable(&SymbolKey::from("nope")).is_none());
}
// The first temp register comes right after the locals.
#[test]
fn test_temp_scope_first() -> Result<()> {
let upcalls = HashMap::default();
let mut global = GlobalSymtable::new(upcalls);
let mut local = global.enter_scope();
local.put_local(SymbolKey::from("a"), SymbolPrototype::Scalar(ExprType::Integer))?;
local.put_local(SymbolKey::from("b"), SymbolPrototype::Scalar(ExprType::Integer))?;
{
let temp = local.frozen();
let mut scope = temp.temp_scope();
assert_eq!(Register::local(2).unwrap(), scope.first()?);
}
Ok(())
}
// With no locals, the first temp is local register 0.
#[test]
fn test_temp_scope_first_no_locals() -> Result<()> {
let upcalls = HashMap::default();
let mut global = GlobalSymtable::new(upcalls);
let mut local = global.enter_scope();
{
let temp = local.frozen();
let mut scope = temp.temp_scope();
assert_eq!(Register::local(0).unwrap(), scope.first()?);
}
Ok(())
}
// A scope opened after an outer allocation starts after the outer scope's temps.
#[test]
fn test_temp_scope_first_with_outer_allocation() -> Result<()> {
let upcalls = HashMap::default();
let mut global = GlobalSymtable::new(upcalls);
let mut local = global.enter_scope();
local.put_local(SymbolKey::from("a"), SymbolPrototype::Scalar(ExprType::Integer))?;
{
let temp = local.frozen();
let mut outer = temp.temp_scope();
assert_eq!(Register::local(1).unwrap(), outer.alloc()?);
let mut inner = temp.temp_scope();
assert_eq!(Register::local(2).unwrap(), inner.first()?);
}
Ok(())
}
// Temps are reused once their scope is dropped, and a new frozen view starts over.
#[test]
fn test_temp_scope() -> Result<()> {
let upcalls = HashMap::default();
let mut global = GlobalSymtable::new(upcalls);
let mut local = global.enter_scope();
assert_eq!(
Register::local(0).unwrap(),
local.put_local(SymbolKey::from("foo"), SymbolPrototype::Scalar(ExprType::Integer))?
);
{
let temp = local.frozen();
{
let mut scope = temp.temp_scope();
assert_eq!(Register::local(1).unwrap(), scope.alloc()?);
{
let mut scope = temp.temp_scope();
assert_eq!(Register::local(2).unwrap(), scope.alloc()?);
assert_eq!(Register::local(3).unwrap(), scope.alloc()?);
assert_eq!(Register::local(4).unwrap(), scope.alloc()?);
}
{
let mut scope = temp.temp_scope();
assert_eq!(Register::local(2).unwrap(), scope.alloc()?);
assert_eq!(Register::local(3).unwrap(), scope.alloc()?);
}
assert_eq!(Register::local(2).unwrap(), scope.alloc()?);
}
}
{
let temp = local.frozen();
{
let mut scope = temp.temp_scope();
assert_eq!(Register::local(1).unwrap(), scope.alloc()?);
}
}
Ok(())
}
// The reserved temp is placed right after the locals.
#[test]
fn test_with_reserved_temp_register_index() -> Result<()> {
let upcalls = HashMap::default();
let mut global = GlobalSymtable::new(upcalls);
let mut local = global.enter_scope();
local.put_local(SymbolKey::from("a"), SymbolPrototype::Scalar(ExprType::Integer))?;
local.put_local(SymbolKey::from("b"), SymbolPrototype::Scalar(ExprType::Integer))?;
local.with_reserved_temp(
|e| e,
|reg, _| {
assert_eq!(Register::local(2).unwrap(), reg);
Ok(())
},
)?;
Ok(())
}
// Temp scopes opened inside a reservation start after the reserved register.
#[test]
fn test_with_reserved_temp_shifts_temp_scope_base() -> Result<()> {
let upcalls = HashMap::default();
let mut global = GlobalSymtable::new(upcalls);
let mut local = global.enter_scope();
local.put_local(SymbolKey::from("a"), SymbolPrototype::Scalar(ExprType::Integer))?;
local.with_reserved_temp(
|e| e,
|reserved, temp| {
assert_eq!(Register::local(1).unwrap(), reserved);
let mut scope = temp.temp_scope();
assert_eq!(Register::local(2).unwrap(), scope.alloc()?);
Ok(())
},
)?;
Ok(())
}
// A reservation is released even when the closure fails, so the next one reuses register 0.
#[test]
fn test_with_reserved_temp_released_after_error() {
let upcalls = HashMap::default();
let mut global = GlobalSymtable::new(upcalls);
let mut local = global.enter_scope();
let err = local
.with_reserved_temp(
|e| e,
|_, _| Err::<(), Error>(Error::OutOfRegisters(RegisterScope::Temp)),
)
.unwrap_err();
assert_eq!("Out of temp registers", err.to_string());
local
.with_reserved_temp(
|e| e,
|reg, _| {
assert_eq!(Register::local(0).unwrap(), reg);
Ok(())
},
)
.unwrap();
}
// Variable lookups work through the frozen view.
#[test]
fn test_temp_scope_lookup_vars() -> Result<()> {
let upcalls = HashMap::default();
let mut global = GlobalSymtable::new(upcalls);
global.put_global(SymbolKey::from("g"), SymbolPrototype::Scalar(ExprType::Integer))?;
let mut local = global.enter_scope();
local.put_local(SymbolKey::from("l"), SymbolPrototype::Scalar(ExprType::Text))?;
{
let temp = local.frozen();
let (reg, proto) = temp.get_local_or_global(&VarRef::new("l", None))?;
assert_eq!(Register::local(0).unwrap(), reg);
assert_eq!(SymbolPrototype::Scalar(ExprType::Text), proto);
let (reg, proto) = temp.get_local_or_global(&VarRef::new("g", None))?;
assert_eq!(Register::global(0).unwrap(), reg);
assert_eq!(SymbolPrototype::Scalar(ExprType::Integer), proto);
}
Ok(())
}
// Callable lookups work through the frozen view.
#[test]
fn test_temp_scope_lookup_callable() -> Result<()> {
let upcalls = HashMap::default();
let mut global = GlobalSymtable::new(upcalls);
let md = CallableMetadataBuilder::new("FOO").test_build();
global.declare_user_callable(&VarRef::new("foo", None), md)?;
let mut local = global.enter_scope();
{
let temp = local.frozen();
assert!(temp.get_callable(&SymbolKey::from("foo")).is_some());
assert!(temp.get_callable(&SymbolKey::from("nope")).is_none());
}
Ok(())
}
// Locals from one scope are not visible in a later scope, and local numbering restarts at 0.
#[test]
fn test_multiple_scopes_independent_locals() -> Result<()> {
let upcalls = HashMap::default();
let mut global = GlobalSymtable::new(upcalls);
{
let mut local = global.enter_scope();
local.put_local(SymbolKey::from("x"), SymbolPrototype::Scalar(ExprType::Integer))?;
}
{
let mut local = global.enter_scope();
let err = local.get_local_or_global(&VarRef::new("x", None)).unwrap_err();
assert_eq!("Undefined global symbol x", err.to_string());
let reg =
local.put_local(SymbolKey::from("y"), SymbolPrototype::Scalar(ExprType::Double))?;
assert_eq!(Register::local(0).unwrap(), reg);
}
Ok(())
}
// Array globals round-trip their element type and dimensionality.
#[test]
fn test_global_put_and_get_array() -> Result<()> {
let upcalls = HashMap::default();
let mut global = GlobalSymtable::new(upcalls);
let reg = global.put_global(
SymbolKey::from("arr"),
SymbolPrototype::Array(ArrayInfo { subtype: ExprType::Integer, ndims: 2 }),
)?;
assert_eq!(Register::global(0).unwrap(), reg);
let (got_reg, proto) = global.get_global(&VarRef::new("arr", None)).unwrap();
assert_eq!(Register::global(0).unwrap(), got_reg);
let SymbolPrototype::Array(info) = proto else { panic!("Expected Array prototype") };
assert_eq!(ExprType::Integer, info.subtype);
assert_eq!(2, info.ndims);
Ok(())
}
// Array locals round-trip their element type and dimensionality.
#[test]
fn test_local_put_and_get_array() -> Result<()> {
let upcalls = HashMap::default();
let mut global = GlobalSymtable::new(upcalls);
let mut local = global.enter_scope();
let reg = local.put_local(
SymbolKey::from("arr"),
SymbolPrototype::Array(ArrayInfo { subtype: ExprType::Double, ndims: 1 }),
)?;
assert_eq!(Register::local(0).unwrap(), reg);
let (got_reg, proto) = local.get_local_or_global(&VarRef::new("arr", None)).unwrap();
assert_eq!(Register::local(0).unwrap(), got_reg);
let SymbolPrototype::Array(info) = proto else { panic!("Expected Array prototype") };
assert_eq!(ExprType::Double, info.subtype);
assert_eq!(1, info.ndims);
Ok(())
}
}