Skip to main content

endbasic_core/compiler/
syms.rs

1// EndBASIC
2// Copyright 2026 Julio Merino
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU Affero General Public License as published by
6// the Free Software Foundation, either version 3 of the License, or
7// (at your option) any later version.
8//
9// This program is distributed in the hope that it will be useful,
10// but WITHOUT ANY WARRANTY; without even the implied warranty of
11// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12// GNU Affero General Public License for more details.
13//
14// You should have received a copy of the GNU Affero General Public License
15// along with this program.  If not, see <https://www.gnu.org/licenses/>.
16
17//! Symbol table for EndBASIC compilation.
18
19use crate::ast::{ExprType, VarRef};
20use crate::bytecode::{Register, RegisterScope};
21use crate::compiler::ids::HashMapWithIds;
22use crate::{CallableMetadata, bytecode};
23use std::cell::Cell;
24use std::cell::RefCell;
25use std::cmp::max;
26use std::collections::HashMap;
27use std::convert::TryFrom;
28use std::fmt;
29use std::rc::Rc;
30
/// Information about an array tracked in the symbol table.
#[derive(Clone, Copy, Debug, PartialEq)]
pub(super) struct ArrayInfo {
    /// Element type of the array (the type of each individual element).
    pub(super) subtype: ExprType,

    /// Number of dimensions of the array.
    pub(super) ndims: usize,
}
40
/// Prototype for a variable-like symbol (scalar or array).
///
/// This is the value stored for each variable in the symbol tables; the index of the register
/// assigned to the variable is tracked separately by the table itself.
#[derive(Clone, Copy, Debug, PartialEq)]
pub(super) enum SymbolPrototype {
    /// An array with the given element type and number of dimensions.
    Array(ArrayInfo),

    /// A scalar variable of the given type.
    Scalar(ExprType),
}
50
/// Errors related to symbols handling.
#[derive(Debug, thiserror::Error)]
#[allow(missing_docs)] // The error messages and names are good enough.
pub(super) enum Error {
    #[error("Cannot redefine {0}")]
    AlreadyDefined(VarRef),

    #[error("Incompatible type annotation in {0} reference")]
    IncompatibleTypeAnnotationInReference(VarRef),

    #[error("Out of {0} registers")]
    OutOfRegisters(RegisterScope),

    // The scope is the one where the failed lookup happened: a local lookup that falls through
    // to the globals and fails there reports the global scope.
    #[error("Undefined {1} symbol {0}")]
    UndefinedSymbol(VarRef, RegisterScope),
}

/// Result type for symbol table operations, using this module's `Error`.
type Result<T> = std::result::Result<T, Error>;
70
/// Name of an entry in the symbol tables, compared case-insensitively.
///
/// The name is canonicalized to ASCII uppercase when the key is built, so keys created from
/// spellings that only differ in ASCII letter case compare, order and hash as equal.
#[derive(Clone, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
pub struct SymbolKey(String);

impl<R: AsRef<str>> From<R> for SymbolKey {
    fn from(value: R) -> Self {
        let raw: &str = value.as_ref();
        SymbolKey(raw.to_ascii_uppercase())
    }
}

impl fmt::Display for SymbolKey {
    /// Writes the canonical (uppercase) spelling of the key, ignoring any width/fill flags.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str(&self.0)
    }
}
88
89/// Gets the register and prototype of a local or global variable if it already exists.
90fn get_var<MKR>(
91    vref: &VarRef,
92    table: &HashMapWithIds<SymbolKey, SymbolPrototype, u8>,
93    make_register: MKR,
94    scope: RegisterScope,
95) -> Result<(Register, SymbolPrototype)>
96where
97    MKR: FnOnce(u8) -> std::result::Result<Register, bytecode::OutOfRegistersError>,
98{
99    let key = SymbolKey::from(&vref.name);
100    match table.get(&key) {
101        Some((SymbolPrototype::Array(info), reg)) => {
102            if !vref.accepts(info.subtype) {
103                return Err(Error::IncompatibleTypeAnnotationInReference(vref.clone()));
104            }
105
106            let reg = make_register(reg).map_err(|_| Error::OutOfRegisters(scope))?;
107            Ok((reg, SymbolPrototype::Array(*info)))
108        }
109
110        Some((SymbolPrototype::Scalar(etype), reg)) => {
111            if !vref.accepts(*etype) {
112                return Err(Error::IncompatibleTypeAnnotationInReference(vref.clone()));
113            }
114
115            let reg = make_register(reg).map_err(|_| Error::OutOfRegisters(scope))?;
116            Ok((reg, SymbolPrototype::Scalar(*etype)))
117        }
118
119        None => Err(Error::UndefinedSymbol(vref.clone(), scope)),
120    }
121}
122
123/// Defines a new local or global variable (or array) and assigns a register to it.
124///
125/// Panics if the symbol already exists.
126fn put_var<MKR>(
127    key: SymbolKey,
128    proto: SymbolPrototype,
129    table: &mut HashMapWithIds<SymbolKey, SymbolPrototype, u8>,
130    make_register: MKR,
131    scope: RegisterScope,
132) -> Result<Register>
133where
134    MKR: FnOnce(u8) -> std::result::Result<Register, bytecode::OutOfRegistersError>,
135{
136    match table.insert(key, proto) {
137        Some((None, reg)) => Ok(make_register(reg).map_err(|_| Error::OutOfRegisters(scope))?),
138
139        Some((Some(_old_proto), _reg)) => {
140            unreachable!("Cannot redefine symbol; caller must check for presence first");
141        }
142
143        None => Err(Error::OutOfRegisters(scope)),
144    }
145}
146
/// Representation of the symbol table for global symbols.
///
/// Globals are variables and callables that are visible from any scope.
#[derive(Clone)]
pub(crate) struct GlobalSymtable {
    /// Map of global variable names to their prototypes and assigned registers.
    globals: HashMapWithIds<SymbolKey, SymbolPrototype, u8>,

    /// Metadata of the built-in callables (upcalls) provided by the runtime, keyed by name.
    upcalls: HashMap<SymbolKey, Rc<CallableMetadata>>,

    /// Map of user-defined callable names to their metadata.  On lookup, entries in here take
    /// precedence over `upcalls` with the same name.
    user_callables: HashMap<SymbolKey, Rc<CallableMetadata>>,
}
161
162impl GlobalSymtable {
163    /// Creates a new global symbol table that knows about the given `upcalls`.
164    pub(crate) fn new(upcalls: HashMap<SymbolKey, Rc<CallableMetadata>>) -> Self {
165        Self { globals: HashMapWithIds::default(), upcalls, user_callables: HashMap::default() }
166    }
167
168    /// Enters a new local scope.
169    pub(crate) fn enter_scope(&mut self) -> LocalSymtable<'_> {
170        LocalSymtable::new(self)
171    }
172
173    /// Gets a global symbol by its `vref`, returning its register and prototype.
174    pub(crate) fn get_global(&self, vref: &VarRef) -> Result<(Register, SymbolPrototype)> {
175        get_var(vref, &self.globals, Register::global, RegisterScope::Global)
176    }
177
178    /// Creates a new global symbol `key` with `proto`.
179    pub(crate) fn put_global(
180        &mut self,
181        key: SymbolKey,
182        proto: SymbolPrototype,
183    ) -> Result<Register> {
184        put_var(key, proto, &mut self.globals, Register::global, RegisterScope::Global)
185    }
186
187    /// Returns true if a global variable `key` is already defined.
188    pub(crate) fn contains_global(&self, key: &SymbolKey) -> bool {
189        self.globals.get(key).is_some()
190    }
191
192    /// Iterates over all global variables, yielding `(key, prototype, register_index)` tuples.
193    pub(crate) fn iter_globals(
194        &self,
195    ) -> impl Iterator<Item = (&SymbolKey, SymbolPrototype, u8)> + '_ {
196        self.globals.iter().map(|(k, v, i)| (k, *v, i))
197    }
198
199    /// Defines a new user-defined `vref` callable with `md` metadata.
200    pub(crate) fn declare_user_callable(
201        &mut self,
202        vref: &VarRef,
203        md: Rc<CallableMetadata>,
204    ) -> Result<()> {
205        let key = SymbolKey::from(&vref.name);
206        if self.globals.get(&key).is_some() {
207            return Err(Error::AlreadyDefined(vref.clone()));
208        }
209        if let Some(previous_md) = self.user_callables.insert(key.clone(), md.clone())
210            && previous_md != md
211        {
212            return Err(Error::AlreadyDefined(vref.clone()));
213        }
214        Ok(())
215    }
216
217    /// Gets a callable by its name `key`.
218    pub(crate) fn get_callable(&self, key: &SymbolKey) -> Option<Rc<CallableMetadata>> {
219        self.user_callables.get(key).or(self.upcalls.get(key)).cloned()
220    }
221}
222
/// Detached copy of the state of a `LocalSymtable`.
///
/// Produced by `LocalSymtable::save` and consumed by `LocalSymtable::restore`, which allows
/// keeping the contents of a local scope around without holding a mutable borrow of the
/// `GlobalSymtable` it belongs to.
///
/// NOTE: `active_temps` is an `Rc<Cell<_>>`, so cloning a snapshot shares that counter between
/// the clones instead of copying it.
#[derive(Clone, Default)]
pub(crate) struct LocalSymtableSnapshot {
    /// Map of local variable names to their prototypes and assigned registers.
    locals: HashMapWithIds<SymbolKey, SymbolPrototype, u8>,

    /// Maximum number of allocated temporary registers in all possible evaluation scopes created
    /// by this local symtable.  This is used to determine the size of the scope for register
    /// allocation purposes at runtime.
    count_temps: u8,

    /// Number of reserved temporary registers that are active outside of `TempScope`.
    active_temps: Rc<Cell<u8>>,
}
236
/// Representation of the symbol table for a local scope.
///
/// A local scope can see all global symbols and defines its own symbols, which can shadow the
/// global ones.
pub(crate) struct LocalSymtable<'a> {
    /// Reference to the parent global symbol table, through which globals and user callables
    /// are declared and looked up.
    symtable: &'a mut GlobalSymtable,

    /// Map of local variable names to their prototypes and assigned registers.
    locals: HashMapWithIds<SymbolKey, SymbolPrototype, u8>,

    /// Maximum number of allocated temporary registers in all possible evaluation scopes created
    /// by this local symtable.  This is used to determine the size of the scope for register
    /// allocation purposes at runtime.
    count_temps: u8,

    /// Number of reserved temporary registers that are active outside of `TempScope`.
    ///
    /// Shared (via `Rc`) with the guards created by `with_reserved_temp`, which decrement it
    /// when they are dropped.
    active_temps: Rc<Cell<u8>>,
}
256
257impl<'a> LocalSymtable<'a> {
258    /// Creates a new local symbol table within the context of a global `symtable`.
259    fn new(symtable: &'a mut GlobalSymtable) -> Self {
260        Self {
261            symtable,
262            locals: HashMapWithIds::default(),
263            count_temps: 0,
264            active_temps: Rc::from(Cell::new(0)),
265        }
266    }
267
268    /// Preserves the state of this local symbol table, detached from the global symbol table
269    /// it belongs to.
270    pub(crate) fn save(self) -> LocalSymtableSnapshot {
271        LocalSymtableSnapshot {
272            locals: self.locals,
273            count_temps: self.count_temps,
274            active_temps: self.active_temps,
275        }
276    }
277
278    /// Reattaches a previous local symbol table content to a global symbol table so that it
279    /// can be used again for compilation.
280    pub(crate) fn restore(
281        symtable: &'a mut GlobalSymtable,
282        snapshot: LocalSymtableSnapshot,
283    ) -> Self {
284        Self {
285            symtable,
286            locals: snapshot.locals,
287            count_temps: snapshot.count_temps,
288            active_temps: snapshot.active_temps,
289        }
290    }
291
292    /// Obtains mutable access to the parent global symtable.
293    pub(crate) fn global(&mut self) -> &mut GlobalSymtable {
294        self.symtable
295    }
296
297    /// Declares a new user-defined `vref` callable with `md` metadata.
298    pub(crate) fn declare_user_callable(
299        &mut self,
300        vref: &VarRef,
301        md: Rc<CallableMetadata>,
302    ) -> Result<()> {
303        self.symtable.declare_user_callable(vref, md)
304    }
305
306    /// Freezes this table to get a `TempSymtable` that can be used to compile expressions.
307    pub(crate) fn frozen(&mut self) -> TempSymtable<'_, 'a> {
308        TempSymtable::new(self)
309    }
310
311    /// Reserves one temporary register for the duration of `f`.
312    pub(crate) fn with_reserved_temp<T, E, ME, F>(
313        &mut self,
314        map_error: ME,
315        f: F,
316    ) -> std::result::Result<T, E>
317    where
318        ME: Fn(Error) -> E,
319        F: FnOnce(Register, &mut TempSymtable<'_, 'a>) -> std::result::Result<T, E>,
320    {
321        struct TempReservationGuard {
322            active_temps: Rc<Cell<u8>>,
323        }
324
325        impl Drop for TempReservationGuard {
326            fn drop(&mut self) {
327                let active_temps = self.active_temps.get();
328                debug_assert!(active_temps > 0);
329                self.active_temps.set(active_temps - 1);
330            }
331        }
332
333        let nlocals = u8::try_from(self.locals.len())
334            .map_err(|_| map_error(Error::OutOfRegisters(RegisterScope::Local)))?;
335        let first_temp = self.active_temps.get();
336        let new_active_temps = first_temp
337            .checked_add(1)
338            .ok_or(map_error(Error::OutOfRegisters(RegisterScope::Temp)))?;
339        self.active_temps.set(new_active_temps);
340        self.count_temps = max(self.count_temps, new_active_temps);
341        let _guard = TempReservationGuard { active_temps: self.active_temps.clone() };
342
343        let reg_idx = u8::try_from(usize::from(nlocals) + usize::from(first_temp))
344            .map_err(|_| map_error(Error::OutOfRegisters(RegisterScope::Temp)))?;
345        let reg = Register::local(reg_idx)
346            .map_err(|_| map_error(Error::OutOfRegisters(RegisterScope::Temp)))?;
347
348        let mut temp = self.frozen();
349        f(reg, &mut temp)
350    }
351
352    /// Creates a new global symbol `key` with `proto` via the parent global symbol table.
353    pub(crate) fn put_global(
354        &mut self,
355        key: SymbolKey,
356        proto: SymbolPrototype,
357    ) -> Result<Register> {
358        self.symtable.put_global(key, proto)
359    }
360
361    /// Gets a symbol by its `vref`, looking for it in the local and global scopes.
362    pub(crate) fn get_local_or_global(&self, vref: &VarRef) -> Result<(Register, SymbolPrototype)> {
363        match get_var(vref, &self.locals, Register::local, RegisterScope::Local) {
364            Ok(local) => Ok(local),
365            Err(Error::UndefinedSymbol(..)) => self.symtable.get_global(vref),
366            Err(e) => Err(e),
367        }
368    }
369
370    /// Gets a callable by its name `key`.
371    pub(crate) fn get_callable(&self, key: &SymbolKey) -> Option<Rc<CallableMetadata>> {
372        self.symtable.get_callable(key)
373    }
374
375    /// Creates a new local symbol `key` with `proto`.
376    pub(crate) fn put_local(&mut self, key: SymbolKey, proto: SymbolPrototype) -> Result<Register> {
377        put_var(key, proto, &mut self.locals, Register::local, RegisterScope::Local)
378    }
379
380    /// Returns true if a local variable `key` is already defined.
381    pub(crate) fn contains_local(&self, key: &SymbolKey) -> bool {
382        self.locals.get(key).is_some()
383    }
384
385    /// Returns true if a global variable `key` is already defined.
386    pub(crate) fn contains_global(&self, key: &SymbolKey) -> bool {
387        self.symtable.contains_global(key)
388    }
389
390    /// Iterates over all global variables, yielding `(key, prototype, register_index)` tuples.
391    pub(crate) fn iter_globals(
392        &self,
393    ) -> impl Iterator<Item = (SymbolKey, SymbolPrototype, u8)> + '_ {
394        self.symtable.iter_globals().map(|(k, v, i)| (k.clone(), v, i))
395    }
396
397    /// Iterates over all local variables, yielding `(key, prototype, register_index)` tuples.
398    pub(crate) fn iter_locals(
399        &self,
400    ) -> impl Iterator<Item = (SymbolKey, SymbolPrototype, u8)> + '_ {
401        self.locals.iter().map(|(k, v, i)| (k.clone(), *v, i))
402    }
403
404    /// Changes the type of an existing local variable `vref` to `new_etype`.
405    ///
406    /// This is used for type inference on first assignment.
407    pub(crate) fn fixup_local_type(&mut self, vref: &VarRef, new_etype: ExprType) -> Result<()> {
408        let key = SymbolKey::from(&vref.name);
409        // TODO: Verify reference type.
410        match self.locals.get_mut(&key) {
411            Some((SymbolPrototype::Array(_), _)) | None => {
412                Err(Error::UndefinedSymbol(vref.clone(), RegisterScope::Local))
413            }
414
415            Some((SymbolPrototype::Scalar(etype), _reg)) => {
416                *etype = new_etype;
417                Ok(())
418            }
419        }
420    }
421}
422
/// A read-only view into a `LocalSymtable` that allows allocating temporary registers.
///
/// This layer on top of `LocalSymtable` may seem redundant because all of the temporary
/// register manipulation happens in `TempScope`, but it is necessary to have this layer
/// to forbid mutations to local variables.  We need to be able to pass a `TempSymtable`
/// across recursive function calls (for expression evaluation), but at the same time we
/// need each call site to have its own `TempScope` for temporary register cleanup.
pub(crate) struct TempSymtable<'temp, 'local> {
    /// Reference to the underlying local symbol table.
    symtable: &'temp mut LocalSymtable<'local>,

    /// Number of temporary registers that were already reserved on creation.
    base_temp: u8,

    /// Index of the next temporary register to allocate, shared with every `TempScope`
    /// created from this view.
    next_temp: Rc<RefCell<u8>>,

    /// Maximum number of allocated temporary registers in a given evaluation (recursion) path,
    /// shared with every `TempScope` created from this view.
    count_temps: Rc<RefCell<u8>>,
}
443
444impl<'temp, 'local> Drop for TempSymtable<'temp, 'local> {
445    fn drop(&mut self) {
446        debug_assert_eq!(self.base_temp, *self.next_temp.borrow(), "Unbalanced temp drops");
447        self.symtable.count_temps = max(self.symtable.count_temps, *self.count_temps.borrow());
448    }
449}
450
451impl<'temp, 'local> TempSymtable<'temp, 'local> {
452    /// Creates a new temporary symbol table from a `local` table.
453    fn new(symtable: &'temp mut LocalSymtable<'local>) -> Self {
454        let base_temp = symtable.active_temps.get();
455        Self {
456            symtable,
457            base_temp,
458            next_temp: Rc::from(RefCell::from(base_temp)),
459            count_temps: Rc::from(RefCell::from(base_temp)),
460        }
461    }
462
463    /// Gets a symbol by its `vref`, looking for it in the local and global scopes.
464    pub(crate) fn get_local_or_global(&self, vref: &VarRef) -> Result<(Register, SymbolPrototype)> {
465        self.symtable.get_local_or_global(vref)
466    }
467
468    /// Gets a callable by its name `key`.
469    pub(crate) fn get_callable(&self, key: &SymbolKey) -> Option<Rc<CallableMetadata>> {
470        self.symtable.get_callable(key)
471    }
472
473    /// Enters a new temporary scope.
474    pub(crate) fn temp_scope(&self) -> TempScope {
475        let nlocals = u8::try_from(self.symtable.locals.len())
476            .expect("Cannot have allocated more locals than u8");
477        TempScope {
478            base_temp: *self.next_temp.borrow(),
479            nlocals,
480            ntemps: 0,
481            next_temp: self.next_temp.clone(),
482            count_temps: self.count_temps.clone(),
483        }
484    }
485}
486
/// A scope for temporary registers.
///
/// Temporaries are allocated on demand and are cleaned up when the scope is dropped.  The
/// counters are shared (via `Rc`) with the `TempSymtable` that created the scope and with any
/// other scope created from it.
pub(crate) struct TempScope {
    /// Number of temporary registers that were already active on scope creation.
    base_temp: u8,

    /// Number of local variables in the enclosing scope, used as the base for temporary registers.
    nlocals: u8,

    /// Number of temporary registers allocated by this scope.
    ntemps: u8,

    /// Shared counter for the next temporary register index to allocate.
    next_temp: Rc<RefCell<u8>>,

    /// Shared counter tracking the maximum number of temporary registers used.
    count_temps: Rc<RefCell<u8>>,
}
506
impl Drop for TempScope {
    /// Returns every temporary allocated by this scope to the shared counter, so that the next
    /// allocation reuses the registers this scope was using.
    fn drop(&mut self) {
        let mut next_temp = self.next_temp.borrow_mut();
        debug_assert!(*next_temp >= self.ntemps);
        *next_temp -= self.ntemps;
    }
}
514
515impl TempScope {
516    /// Returns the first register available for this scope.
517    pub(crate) fn first(&mut self) -> Result<Register> {
518        let reg = u8::try_from(usize::from(self.nlocals) + usize::from(self.base_temp))
519            .map_err(|_| Error::OutOfRegisters(RegisterScope::Temp))?;
520        Register::local(reg).map_err(|_| Error::OutOfRegisters(RegisterScope::Temp))
521    }
522
523    /// Allocates a new temporary register.
524    pub(crate) fn alloc(&mut self) -> Result<Register> {
525        let temp;
526        let new_next_temp;
527        {
528            let mut next_temp = self.next_temp.borrow_mut();
529            temp = *next_temp;
530            self.ntemps += 1;
531            new_next_temp = match next_temp.checked_add(1) {
532                Some(reg) => reg,
533                None => return Err(Error::OutOfRegisters(RegisterScope::Temp)),
534            };
535            *next_temp = new_next_temp;
536        }
537
538        {
539            let mut count_temps = self.count_temps.borrow_mut();
540            *count_temps = max(*count_temps, new_next_temp);
541        }
542
543        match u8::try_from(usize::from(self.nlocals) + usize::from(temp)) {
544            Ok(reg) => {
545                Ok(Register::local(reg).map_err(|_| Error::OutOfRegisters(RegisterScope::Temp))?)
546            }
547            Err(_) => Err(Error::OutOfRegisters(RegisterScope::Temp)),
548        }
549    }
550}
551
552#[cfg(test)]
553mod tests {
554    use super::*;
555    use crate::CallableMetadataBuilder;
556
557    #[test]
558    fn test_symbol_key_case_insensitive() {
559        assert_eq!(SymbolKey::from("foo"), SymbolKey::from("FOO"));
560        assert_eq!(SymbolKey::from("Foo"), SymbolKey::from("fOo"));
561    }
562
563    #[test]
564    fn test_symbol_key_display() {
565        assert_eq!("FOO", format!("{}", SymbolKey::from("foo")));
566    }
567
568    #[test]
569    fn test_global_put_and_get() -> Result<()> {
570        let upcalls = HashMap::default();
571        let mut global = GlobalSymtable::new(upcalls);
572
573        let reg =
574            global.put_global(SymbolKey::from("x"), SymbolPrototype::Scalar(ExprType::Integer))?;
575        assert_eq!(Register::global(0).unwrap(), reg);
576
577        let reg =
578            global.put_global(SymbolKey::from("y"), SymbolPrototype::Scalar(ExprType::Text))?;
579        assert_eq!(Register::global(1).unwrap(), reg);
580
581        // Lookup with untyped ref succeeds.
582        let (reg, proto) = global.get_global(&VarRef::new("x", None))?;
583        assert_eq!(Register::global(0).unwrap(), reg);
584        assert_eq!(SymbolPrototype::Scalar(ExprType::Integer), proto);
585
586        // Lookup with matching typed ref succeeds.
587        let (reg, proto) = global.get_global(&VarRef::new("y", Some(ExprType::Text)))?;
588        assert_eq!(Register::global(1).unwrap(), reg);
589        assert_eq!(SymbolPrototype::Scalar(ExprType::Text), proto);
590
591        Ok(())
592    }
593
594    #[test]
595    fn test_global_get_case_insensitive() -> Result<()> {
596        let upcalls = HashMap::default();
597        let mut global = GlobalSymtable::new(upcalls);
598        global.put_global(SymbolKey::from("MyVar"), SymbolPrototype::Scalar(ExprType::Double))?;
599
600        let (reg, proto) = global.get_global(&VarRef::new("myvar", None))?;
601        assert_eq!(Register::global(0).unwrap(), reg);
602        assert_eq!(SymbolPrototype::Scalar(ExprType::Double), proto);
603
604        let (reg2, _) = global.get_global(&VarRef::new("MYVAR", None))?;
605        assert_eq!(reg, reg2);
606        Ok(())
607    }
608
609    #[test]
610    fn test_global_get_incompatible_type() {
611        let upcalls = HashMap::default();
612        let mut global = GlobalSymtable::new(upcalls);
613        global
614            .put_global(SymbolKey::from("x"), SymbolPrototype::Scalar(ExprType::Integer))
615            .unwrap();
616
617        let err = global.get_global(&VarRef::new("x", Some(ExprType::Text))).unwrap_err();
618        assert_eq!("Incompatible type annotation in x$ reference", err.to_string());
619    }
620
621    #[test]
622    fn test_global_get_undefined() {
623        let upcalls = HashMap::default();
624        let global = GlobalSymtable::new(upcalls);
625
626        let err = global.get_global(&VarRef::new("x", None)).unwrap_err();
627        assert_eq!("Undefined global symbol x", err.to_string());
628    }
629
630    #[test]
631    fn test_local_put_and_get() -> Result<()> {
632        let upcalls = HashMap::default();
633        let mut global = GlobalSymtable::new(upcalls);
634        let mut local = global.enter_scope();
635
636        let reg =
637            local.put_local(SymbolKey::from("a"), SymbolPrototype::Scalar(ExprType::Boolean))?;
638        assert_eq!(Register::local(0).unwrap(), reg);
639
640        let reg =
641            local.put_local(SymbolKey::from("b"), SymbolPrototype::Scalar(ExprType::Double))?;
642        assert_eq!(Register::local(1).unwrap(), reg);
643
644        let (reg, proto) = local.get_local_or_global(&VarRef::new("a", None))?;
645        assert_eq!(Register::local(0).unwrap(), reg);
646        assert_eq!(SymbolPrototype::Scalar(ExprType::Boolean), proto);
647
648        Ok(())
649    }
650
651    #[test]
652    fn test_local_shadows_global() -> Result<()> {
653        let upcalls = HashMap::default();
654        let mut global = GlobalSymtable::new(upcalls);
655        global.put_global(SymbolKey::from("x"), SymbolPrototype::Scalar(ExprType::Integer))?;
656
657        let mut local = global.enter_scope();
658        local.put_local(SymbolKey::from("x"), SymbolPrototype::Scalar(ExprType::Text))?;
659
660        let (reg, proto) = local.get_local_or_global(&VarRef::new("x", None))?;
661        assert_eq!(Register::local(0).unwrap(), reg);
662        assert_eq!(SymbolPrototype::Scalar(ExprType::Text), proto);
663
664        Ok(())
665    }
666
667    #[test]
668    fn test_local_falls_through_to_global() -> Result<()> {
669        let upcalls = HashMap::default();
670        let mut global = GlobalSymtable::new(upcalls);
671        global.put_global(SymbolKey::from("g"), SymbolPrototype::Scalar(ExprType::Integer))?;
672
673        let local = global.enter_scope();
674        let (reg, proto) = local.get_local_or_global(&VarRef::new("g", None))?;
675        assert_eq!(Register::global(0).unwrap(), reg);
676        assert_eq!(SymbolPrototype::Scalar(ExprType::Integer), proto);
677
678        Ok(())
679    }
680
681    #[test]
682    fn test_local_get_undefined() {
683        let upcalls = HashMap::default();
684        let mut global = GlobalSymtable::new(upcalls);
685        let local = global.enter_scope();
686
687        let err = local.get_local_or_global(&VarRef::new("nope", None)).unwrap_err();
688        assert_eq!("Undefined global symbol nope", err.to_string());
689    }
690
691    #[test]
692    fn test_local_put_global_through_local() -> Result<()> {
693        let upcalls = HashMap::default();
694        let mut global = GlobalSymtable::new(upcalls);
695        let mut local = global.enter_scope();
696
697        let reg =
698            local.put_global(SymbolKey::from("g"), SymbolPrototype::Scalar(ExprType::Integer))?;
699        assert_eq!(Register::global(0).unwrap(), reg);
700
701        // Should be visible from the local scope via fallthrough.
702        let (reg, proto) = local.get_local_or_global(&VarRef::new("g", None))?;
703        assert_eq!(Register::global(0).unwrap(), reg);
704        assert_eq!(SymbolPrototype::Scalar(ExprType::Integer), proto);
705
706        Ok(())
707    }
708
709    #[test]
710    fn test_fixup_local_type() -> Result<()> {
711        let upcalls = HashMap::default();
712        let mut global = GlobalSymtable::new(upcalls);
713        let mut local = global.enter_scope();
714
715        local.put_local(SymbolKey::from("x"), SymbolPrototype::Scalar(ExprType::Integer))?;
716        local.fixup_local_type(&VarRef::new("x", None), ExprType::Double)?;
717
718        let (_, proto) = local.get_local_or_global(&VarRef::new("x", None))?;
719        assert_eq!(SymbolPrototype::Scalar(ExprType::Double), proto);
720
721        Ok(())
722    }
723
724    #[test]
725    fn test_fixup_local_type_undefined() {
726        let upcalls = HashMap::default();
727        let mut global = GlobalSymtable::new(upcalls);
728        let mut local = global.enter_scope();
729
730        let err =
731            local.fixup_local_type(&VarRef::new("nope", None), ExprType::Integer).unwrap_err();
732        assert_eq!("Undefined local symbol nope", err.to_string());
733    }
734
735    #[test]
736    fn test_define_and_get_user_callable() -> Result<()> {
737        let upcalls = HashMap::default();
738        let mut global = GlobalSymtable::new(upcalls);
739
740        let md = CallableMetadataBuilder::new("MY_FUNC")
741            .with_return_type(ExprType::Integer)
742            .test_build();
743        global.declare_user_callable(&VarRef::new("my_func", None), md)?;
744
745        let found = global.get_callable(&SymbolKey::from("my_func"));
746        assert!(found.is_some());
747        assert_eq!("MY_FUNC", found.unwrap().name());
748
749        Ok(())
750    }
751
752    #[test]
753    fn test_define_user_callable_already_defined_but_is_compatible() -> Result<()> {
754        let upcalls = HashMap::default();
755        let mut global = GlobalSymtable::new(upcalls);
756
757        let md = CallableMetadataBuilder::new("DUP").test_build();
758        global.declare_user_callable(&VarRef::new("dup", None), md)?;
759
760        let md2 = CallableMetadataBuilder::new("DUP").test_build();
761        global.declare_user_callable(&VarRef::new("dup", None), md2)?;
762
763        Ok(())
764    }
765
766    #[test]
767    fn test_define_user_callable_already_defined_but_is_incompatible() -> Result<()> {
768        let upcalls = HashMap::default();
769        let mut global = GlobalSymtable::new(upcalls);
770
771        let md = CallableMetadataBuilder::new("DUP").test_build();
772        global.declare_user_callable(&VarRef::new("dup", None), md)?;
773
774        let md2 =
775            CallableMetadataBuilder::new("DUP").with_return_type(ExprType::Integer).test_build();
776        let err = global.declare_user_callable(&VarRef::new("dup", None), md2).unwrap_err();
777        assert_eq!("Cannot redefine dup", err.to_string());
778
779        Ok(())
780    }
781
782    #[test]
783    fn test_define_user_callable_via_local() -> Result<()> {
784        let upcalls = HashMap::default();
785        let mut global = GlobalSymtable::new(upcalls);
786        let mut local = global.enter_scope();
787
788        let md = CallableMetadataBuilder::new("SUB1").test_build();
789        local.declare_user_callable(&VarRef::new("sub1", None), md)?;
790
791        let found = local.get_callable(&SymbolKey::from("sub1"));
792        assert!(found.is_some());
793
794        Ok(())
795    }
796
797    #[test]
798    fn test_get_callable_upcall() {
799        let key = SymbolKey::from("BUILTIN");
800        let md = CallableMetadataBuilder::new("BUILTIN").test_build();
801        let mut upcalls_map = HashMap::new();
802        upcalls_map.insert(key, md);
803
804        let global = GlobalSymtable::new(upcalls_map);
805        let found = global.get_callable(&SymbolKey::from("builtin"));
806        assert!(found.is_some());
807        assert_eq!("BUILTIN", found.unwrap().name());
808    }
809
810    #[test]
811    fn test_user_callable_shadows_upcall() {
812        let key = SymbolKey::from("SHARED");
813        let builtin_md =
814            CallableMetadataBuilder::new("SHARED").with_return_type(ExprType::Boolean).test_build();
815        let mut upcalls_map = HashMap::new();
816        upcalls_map.insert(key, builtin_md);
817
818        let mut global = GlobalSymtable::new(upcalls_map);
819        let user_md =
820            CallableMetadataBuilder::new("SHARED").with_return_type(ExprType::Integer).test_build();
821        global.declare_user_callable(&VarRef::new("shared", None), user_md).unwrap();
822
823        let found = global.get_callable(&SymbolKey::from("shared")).unwrap();
824        assert_eq!(Some(ExprType::Integer), found.return_type());
825    }
826
827    #[test]
828    fn test_get_callable_not_found() {
829        let upcalls = HashMap::default();
830        let global = GlobalSymtable::new(upcalls);
831        assert!(global.get_callable(&SymbolKey::from("nope")).is_none());
832    }
833
834    #[test]
835    fn test_temp_scope_first() -> Result<()> {
836        let upcalls = HashMap::default();
837        let mut global = GlobalSymtable::new(upcalls);
838        let mut local = global.enter_scope();
839        local.put_local(SymbolKey::from("a"), SymbolPrototype::Scalar(ExprType::Integer))?;
840        local.put_local(SymbolKey::from("b"), SymbolPrototype::Scalar(ExprType::Integer))?;
841        {
842            let temp = local.frozen();
843            let mut scope = temp.temp_scope();
844            assert_eq!(Register::local(2).unwrap(), scope.first()?);
845        }
846        Ok(())
847    }
848
849    #[test]
850    fn test_temp_scope_first_no_locals() -> Result<()> {
851        let upcalls = HashMap::default();
852        let mut global = GlobalSymtable::new(upcalls);
853        let mut local = global.enter_scope();
854        {
855            let temp = local.frozen();
856            let mut scope = temp.temp_scope();
857            assert_eq!(Register::local(0).unwrap(), scope.first()?);
858        }
859        Ok(())
860    }
861
862    #[test]
863    fn test_temp_scope_first_with_outer_allocation() -> Result<()> {
864        let upcalls = HashMap::default();
865        let mut global = GlobalSymtable::new(upcalls);
866        let mut local = global.enter_scope();
867        local.put_local(SymbolKey::from("a"), SymbolPrototype::Scalar(ExprType::Integer))?;
868        {
869            let temp = local.frozen();
870            let mut outer = temp.temp_scope();
871            assert_eq!(Register::local(1).unwrap(), outer.alloc()?);
872
873            let mut inner = temp.temp_scope();
874            assert_eq!(Register::local(2).unwrap(), inner.first()?);
875        }
876        Ok(())
877    }
878
879    #[test]
880    fn test_temp_scope() -> Result<()> {
881        let upcalls = HashMap::default();
882        let mut global = GlobalSymtable::new(upcalls);
883        let mut local = global.enter_scope();
884        assert_eq!(
885            Register::local(0).unwrap(),
886            local.put_local(SymbolKey::from("foo"), SymbolPrototype::Scalar(ExprType::Integer))?
887        );
888        {
889            let temp = local.frozen();
890            {
891                let mut scope = temp.temp_scope();
892                assert_eq!(Register::local(1).unwrap(), scope.alloc()?);
893                {
894                    let mut scope = temp.temp_scope();
895                    assert_eq!(Register::local(2).unwrap(), scope.alloc()?);
896                    assert_eq!(Register::local(3).unwrap(), scope.alloc()?);
897                    assert_eq!(Register::local(4).unwrap(), scope.alloc()?);
898                }
899                {
900                    let mut scope = temp.temp_scope();
901                    assert_eq!(Register::local(2).unwrap(), scope.alloc()?);
902                    assert_eq!(Register::local(3).unwrap(), scope.alloc()?);
903                }
904                assert_eq!(Register::local(2).unwrap(), scope.alloc()?);
905            }
906        }
907        {
908            let temp = local.frozen();
909            {
910                let mut scope = temp.temp_scope();
911                assert_eq!(Register::local(1).unwrap(), scope.alloc()?);
912            }
913        }
914        Ok(())
915    }
916
917    #[test]
918    fn test_with_reserved_temp_register_index() -> Result<()> {
919        let upcalls = HashMap::default();
920        let mut global = GlobalSymtable::new(upcalls);
921        let mut local = global.enter_scope();
922        local.put_local(SymbolKey::from("a"), SymbolPrototype::Scalar(ExprType::Integer))?;
923        local.put_local(SymbolKey::from("b"), SymbolPrototype::Scalar(ExprType::Integer))?;
924
925        local.with_reserved_temp(
926            |e| e,
927            |reg, _| {
928                assert_eq!(Register::local(2).unwrap(), reg);
929                Ok(())
930            },
931        )?;
932
933        Ok(())
934    }
935
936    #[test]
937    fn test_with_reserved_temp_shifts_temp_scope_base() -> Result<()> {
938        let upcalls = HashMap::default();
939        let mut global = GlobalSymtable::new(upcalls);
940        let mut local = global.enter_scope();
941        local.put_local(SymbolKey::from("a"), SymbolPrototype::Scalar(ExprType::Integer))?;
942
943        local.with_reserved_temp(
944            |e| e,
945            |reserved, temp| {
946                assert_eq!(Register::local(1).unwrap(), reserved);
947                let mut scope = temp.temp_scope();
948                assert_eq!(Register::local(2).unwrap(), scope.alloc()?);
949                Ok(())
950            },
951        )?;
952
953        Ok(())
954    }
955
956    #[test]
957    fn test_with_reserved_temp_released_after_error() {
958        let upcalls = HashMap::default();
959        let mut global = GlobalSymtable::new(upcalls);
960        let mut local = global.enter_scope();
961
962        let err = local
963            .with_reserved_temp(
964                |e| e,
965                |_, _| Err::<(), Error>(Error::OutOfRegisters(RegisterScope::Temp)),
966            )
967            .unwrap_err();
968        assert_eq!("Out of temp registers", err.to_string());
969
970        local
971            .with_reserved_temp(
972                |e| e,
973                |reg, _| {
974                    assert_eq!(Register::local(0).unwrap(), reg);
975                    Ok(())
976                },
977            )
978            .unwrap();
979    }
980
981    #[test]
982    fn test_temp_scope_lookup_vars() -> Result<()> {
983        let upcalls = HashMap::default();
984        let mut global = GlobalSymtable::new(upcalls);
985        global.put_global(SymbolKey::from("g"), SymbolPrototype::Scalar(ExprType::Integer))?;
986        let mut local = global.enter_scope();
987        local.put_local(SymbolKey::from("l"), SymbolPrototype::Scalar(ExprType::Text))?;
988
989        {
990            let temp = local.frozen();
991
992            let (reg, proto) = temp.get_local_or_global(&VarRef::new("l", None))?;
993            assert_eq!(Register::local(0).unwrap(), reg);
994            assert_eq!(SymbolPrototype::Scalar(ExprType::Text), proto);
995
996            let (reg, proto) = temp.get_local_or_global(&VarRef::new("g", None))?;
997            assert_eq!(Register::global(0).unwrap(), reg);
998            assert_eq!(SymbolPrototype::Scalar(ExprType::Integer), proto);
999        }
1000
1001        Ok(())
1002    }
1003
1004    #[test]
1005    fn test_temp_scope_lookup_callable() -> Result<()> {
1006        let upcalls = HashMap::default();
1007        let mut global = GlobalSymtable::new(upcalls);
1008        let md = CallableMetadataBuilder::new("FOO").test_build();
1009        global.declare_user_callable(&VarRef::new("foo", None), md)?;
1010
1011        let mut local = global.enter_scope();
1012        {
1013            let temp = local.frozen();
1014            assert!(temp.get_callable(&SymbolKey::from("foo")).is_some());
1015            assert!(temp.get_callable(&SymbolKey::from("nope")).is_none());
1016        }
1017
1018        Ok(())
1019    }
1020
1021    #[test]
1022    fn test_multiple_scopes_independent_locals() -> Result<()> {
1023        let upcalls = HashMap::default();
1024        let mut global = GlobalSymtable::new(upcalls);
1025
1026        {
1027            let mut local = global.enter_scope();
1028            local.put_local(SymbolKey::from("x"), SymbolPrototype::Scalar(ExprType::Integer))?;
1029        }
1030
1031        {
1032            let mut local = global.enter_scope();
1033            // "x" should not exist in this new scope.
1034            let err = local.get_local_or_global(&VarRef::new("x", None)).unwrap_err();
1035            assert_eq!("Undefined global symbol x", err.to_string());
1036
1037            let reg =
1038                local.put_local(SymbolKey::from("y"), SymbolPrototype::Scalar(ExprType::Double))?;
1039            assert_eq!(Register::local(0).unwrap(), reg);
1040        }
1041
1042        Ok(())
1043    }
1044
1045    #[test]
1046    fn test_global_put_and_get_array() -> Result<()> {
1047        let upcalls = HashMap::default();
1048        let mut global = GlobalSymtable::new(upcalls);
1049
1050        let reg = global.put_global(
1051            SymbolKey::from("arr"),
1052            SymbolPrototype::Array(ArrayInfo { subtype: ExprType::Integer, ndims: 2 }),
1053        )?;
1054        assert_eq!(Register::global(0).unwrap(), reg);
1055
1056        let (got_reg, proto) = global.get_global(&VarRef::new("arr", None)).unwrap();
1057        assert_eq!(Register::global(0).unwrap(), got_reg);
1058        let SymbolPrototype::Array(info) = proto else { panic!("Expected Array prototype") };
1059        assert_eq!(ExprType::Integer, info.subtype);
1060        assert_eq!(2, info.ndims);
1061
1062        Ok(())
1063    }
1064
1065    #[test]
1066    fn test_local_put_and_get_array() -> Result<()> {
1067        let upcalls = HashMap::default();
1068        let mut global = GlobalSymtable::new(upcalls);
1069        let mut local = global.enter_scope();
1070
1071        let reg = local.put_local(
1072            SymbolKey::from("arr"),
1073            SymbolPrototype::Array(ArrayInfo { subtype: ExprType::Double, ndims: 1 }),
1074        )?;
1075        assert_eq!(Register::local(0).unwrap(), reg);
1076
1077        let (got_reg, proto) = local.get_local_or_global(&VarRef::new("arr", None)).unwrap();
1078        assert_eq!(Register::local(0).unwrap(), got_reg);
1079        let SymbolPrototype::Array(info) = proto else { panic!("Expected Array prototype") };
1080        assert_eq!(ExprType::Double, info.subtype);
1081        assert_eq!(1, info.ndims);
1082
1083        Ok(())
1084    }
1085}