1use crate::ast::{ExprType, VarRef};
20use crate::bytecode::{Register, RegisterScope};
21use crate::compiler::ids::HashMapWithIds;
22use crate::{CallableMetadata, bytecode};
23use std::cell::Cell;
24use std::cell::RefCell;
25use std::cmp::max;
26use std::collections::HashMap;
27use std::convert::TryFrom;
28use std::fmt;
29use std::rc::Rc;
30
/// Compile-time description of an array variable.
#[derive(Clone, Copy, Debug, PartialEq)]
pub(super) struct ArrayInfo {
    /// Type of the array's elements; references to the array are checked against this type.
    pub(super) subtype: ExprType,

    /// Number of dimensions of the array.
    pub(super) ndims: usize,
}
40
/// Compile-time shape and type of a variable stored in a symbol table.
#[derive(Clone, Copy, Debug, PartialEq)]
pub(super) enum SymbolPrototype {
    /// An array variable, described by its element type and dimension count.
    Array(ArrayInfo),

    /// A scalar variable of the given type.
    Scalar(ExprType),
}
50
/// Errors reported by the symbol table operations in this module.
#[derive(Debug, thiserror::Error)]
#[allow(missing_docs)]
pub(super) enum Error {
    /// A symbol being declared clashes with an existing definition.
    #[error("Cannot redefine {0}")]
    AlreadyDefined(VarRef),

    /// A reference's type annotation does not accept the type the symbol was declared with.
    #[error("Incompatible type annotation in {0} reference")]
    IncompatibleTypeAnnotationInReference(VarRef),

    /// No more registers can be assigned in the given scope.
    #[error("Out of {0} registers")]
    OutOfRegisters(RegisterScope),

    /// The referenced symbol (field 0) was not found in the searched scope (field 1).
    #[error("Undefined {1} symbol {0}")]
    UndefinedSymbol(VarRef, RegisterScope),
}
67
/// Result type for symbol table operations, using this module's `Error`.
type Result<T> = std::result::Result<T, Error>;
70
/// Case-insensitive name used as the key of every symbol table lookup.
///
/// The name is normalized to ASCII uppercase when the key is built, so `foo`, `Foo` and `FOO`
/// produce equal (and equally hashed and ordered) keys. Non-ASCII characters are kept as-is.
#[derive(Clone, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
pub struct SymbolKey(String);

impl<R: AsRef<str>> From<R> for SymbolKey {
    /// Builds a key from any string-like `value`, uppercasing its ASCII letters.
    fn from(value: R) -> Self {
        let raw: &str = value.as_ref();
        SymbolKey(raw.to_ascii_uppercase())
    }
}

impl fmt::Display for SymbolKey {
    /// Writes the normalized (uppercased) name; width/fill flags are not applied.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str(&self.0)
    }
}
88
89fn get_var<MKR>(
91 vref: &VarRef,
92 table: &HashMapWithIds<SymbolKey, SymbolPrototype, u8>,
93 make_register: MKR,
94 scope: RegisterScope,
95) -> Result<(Register, SymbolPrototype)>
96where
97 MKR: FnOnce(u8) -> std::result::Result<Register, bytecode::OutOfRegistersError>,
98{
99 let key = SymbolKey::from(&vref.name);
100 match table.get(&key) {
101 Some((SymbolPrototype::Array(info), reg)) => {
102 if !vref.accepts(info.subtype) {
103 return Err(Error::IncompatibleTypeAnnotationInReference(vref.clone()));
104 }
105
106 let reg = make_register(reg).map_err(|_| Error::OutOfRegisters(scope))?;
107 Ok((reg, SymbolPrototype::Array(*info)))
108 }
109
110 Some((SymbolPrototype::Scalar(etype), reg)) => {
111 if !vref.accepts(*etype) {
112 return Err(Error::IncompatibleTypeAnnotationInReference(vref.clone()));
113 }
114
115 let reg = make_register(reg).map_err(|_| Error::OutOfRegisters(scope))?;
116 Ok((reg, SymbolPrototype::Scalar(*etype)))
117 }
118
119 None => Err(Error::UndefinedSymbol(vref.clone(), scope)),
120 }
121}
122
123fn put_var<MKR>(
127 key: SymbolKey,
128 proto: SymbolPrototype,
129 table: &mut HashMapWithIds<SymbolKey, SymbolPrototype, u8>,
130 make_register: MKR,
131 scope: RegisterScope,
132) -> Result<Register>
133where
134 MKR: FnOnce(u8) -> std::result::Result<Register, bytecode::OutOfRegistersError>,
135{
136 match table.insert(key, proto) {
137 Some((None, reg)) => Ok(make_register(reg).map_err(|_| Error::OutOfRegisters(scope))?),
138
139 Some((Some(_old_proto), _reg)) => {
140 unreachable!("Cannot redefine symbol; caller must check for presence first");
141 }
142
143 None => Err(Error::OutOfRegisters(scope)),
144 }
145}
146
/// Program-wide symbol table holding global variables and all callables.
#[derive(Clone)]
pub(crate) struct GlobalSymtable {
    /// Global variables, each with its prototype and the index used to build its global register.
    globals: HashMapWithIds<SymbolKey, SymbolPrototype, u8>,

    /// Callables supplied at construction time; consulted after `user_callables` on lookup.
    upcalls: HashMap<SymbolKey, Rc<CallableMetadata>>,

    /// Callables registered through `declare_user_callable`.
    user_callables: HashMap<SymbolKey, Rc<CallableMetadata>>,
}
161
162impl GlobalSymtable {
163 pub(crate) fn new(upcalls: HashMap<SymbolKey, Rc<CallableMetadata>>) -> Self {
165 Self { globals: HashMapWithIds::default(), upcalls, user_callables: HashMap::default() }
166 }
167
168 pub(crate) fn enter_scope(&mut self) -> LocalSymtable<'_> {
170 LocalSymtable::new(self)
171 }
172
173 pub(crate) fn get_global(&self, vref: &VarRef) -> Result<(Register, SymbolPrototype)> {
175 get_var(vref, &self.globals, Register::global, RegisterScope::Global)
176 }
177
178 pub(crate) fn put_global(
180 &mut self,
181 key: SymbolKey,
182 proto: SymbolPrototype,
183 ) -> Result<Register> {
184 put_var(key, proto, &mut self.globals, Register::global, RegisterScope::Global)
185 }
186
187 pub(crate) fn contains_global(&self, key: &SymbolKey) -> bool {
189 self.globals.get(key).is_some()
190 }
191
192 pub(crate) fn iter_globals(
194 &self,
195 ) -> impl Iterator<Item = (&SymbolKey, SymbolPrototype, u8)> + '_ {
196 self.globals.iter().map(|(k, v, i)| (k, *v, i))
197 }
198
199 pub(crate) fn declare_user_callable(
201 &mut self,
202 vref: &VarRef,
203 md: Rc<CallableMetadata>,
204 ) -> Result<()> {
205 let key = SymbolKey::from(&vref.name);
206 if self.globals.get(&key).is_some() {
207 return Err(Error::AlreadyDefined(vref.clone()));
208 }
209 if let Some(previous_md) = self.user_callables.insert(key.clone(), md.clone())
210 && previous_md != md
211 {
212 return Err(Error::AlreadyDefined(vref.clone()));
213 }
214 Ok(())
215 }
216
217 pub(crate) fn get_callable(&self, key: &SymbolKey) -> Option<Rc<CallableMetadata>> {
219 self.user_callables.get(key).or(self.upcalls.get(key)).cloned()
220 }
221}
222
/// State of a `LocalSymtable` captured by `LocalSymtable::save`.
///
/// It holds no borrow of the global table, so it can be stored and later resumed with
/// `LocalSymtable::restore`.
#[derive(Clone, Default)]
pub(crate) struct LocalSymtableSnapshot {
    /// Local variables of the saved scope.
    locals: HashMapWithIds<SymbolKey, SymbolPrototype, u8>,

    /// Highest number of temporaries recorded as in use in the saved scope.
    count_temps: u8,

    /// Number of temporaries currently reserved via `with_reserved_temp`; shared through `Rc`.
    active_temps: Rc<Cell<u8>>,
}
236
/// Symbol table for a local scope, layered on top of a mutably borrowed `GlobalSymtable`.
pub(crate) struct LocalSymtable<'a> {
    /// The global table; lookups fall back to it when a name is not local.
    symtable: &'a mut GlobalSymtable,

    /// Local variables, each with its prototype and the index used to build its local register.
    locals: HashMapWithIds<SymbolKey, SymbolPrototype, u8>,

    /// Highest number of temporaries recorded as in use in this scope. Updated by
    /// `with_reserved_temp` and when a `TempSymtable` is dropped.
    count_temps: u8,

    /// Number of temporaries currently reserved via `with_reserved_temp`. Shared with the
    /// reservation guard, which decrements it on release.
    active_temps: Rc<Cell<u8>>,
}
256
257impl<'a> LocalSymtable<'a> {
258 fn new(symtable: &'a mut GlobalSymtable) -> Self {
260 Self {
261 symtable,
262 locals: HashMapWithIds::default(),
263 count_temps: 0,
264 active_temps: Rc::from(Cell::new(0)),
265 }
266 }
267
268 pub(crate) fn save(self) -> LocalSymtableSnapshot {
271 LocalSymtableSnapshot {
272 locals: self.locals,
273 count_temps: self.count_temps,
274 active_temps: self.active_temps,
275 }
276 }
277
278 pub(crate) fn restore(
281 symtable: &'a mut GlobalSymtable,
282 snapshot: LocalSymtableSnapshot,
283 ) -> Self {
284 Self {
285 symtable,
286 locals: snapshot.locals,
287 count_temps: snapshot.count_temps,
288 active_temps: snapshot.active_temps,
289 }
290 }
291
292 pub(crate) fn global(&mut self) -> &mut GlobalSymtable {
294 self.symtable
295 }
296
297 pub(crate) fn declare_user_callable(
299 &mut self,
300 vref: &VarRef,
301 md: Rc<CallableMetadata>,
302 ) -> Result<()> {
303 self.symtable.declare_user_callable(vref, md)
304 }
305
306 pub(crate) fn frozen(&mut self) -> TempSymtable<'_, 'a> {
308 TempSymtable::new(self)
309 }
310
311 pub(crate) fn with_reserved_temp<T, E, ME, F>(
313 &mut self,
314 map_error: ME,
315 f: F,
316 ) -> std::result::Result<T, E>
317 where
318 ME: Fn(Error) -> E,
319 F: FnOnce(Register, &mut TempSymtable<'_, 'a>) -> std::result::Result<T, E>,
320 {
321 struct TempReservationGuard {
322 active_temps: Rc<Cell<u8>>,
323 }
324
325 impl Drop for TempReservationGuard {
326 fn drop(&mut self) {
327 let active_temps = self.active_temps.get();
328 debug_assert!(active_temps > 0);
329 self.active_temps.set(active_temps - 1);
330 }
331 }
332
333 let nlocals = u8::try_from(self.locals.len())
334 .map_err(|_| map_error(Error::OutOfRegisters(RegisterScope::Local)))?;
335 let first_temp = self.active_temps.get();
336 let new_active_temps = first_temp
337 .checked_add(1)
338 .ok_or(map_error(Error::OutOfRegisters(RegisterScope::Temp)))?;
339 self.active_temps.set(new_active_temps);
340 self.count_temps = max(self.count_temps, new_active_temps);
341 let _guard = TempReservationGuard { active_temps: self.active_temps.clone() };
342
343 let reg_idx = u8::try_from(usize::from(nlocals) + usize::from(first_temp))
344 .map_err(|_| map_error(Error::OutOfRegisters(RegisterScope::Temp)))?;
345 let reg = Register::local(reg_idx)
346 .map_err(|_| map_error(Error::OutOfRegisters(RegisterScope::Temp)))?;
347
348 let mut temp = self.frozen();
349 f(reg, &mut temp)
350 }
351
352 pub(crate) fn put_global(
354 &mut self,
355 key: SymbolKey,
356 proto: SymbolPrototype,
357 ) -> Result<Register> {
358 self.symtable.put_global(key, proto)
359 }
360
361 pub(crate) fn get_local_or_global(&self, vref: &VarRef) -> Result<(Register, SymbolPrototype)> {
363 match get_var(vref, &self.locals, Register::local, RegisterScope::Local) {
364 Ok(local) => Ok(local),
365 Err(Error::UndefinedSymbol(..)) => self.symtable.get_global(vref),
366 Err(e) => Err(e),
367 }
368 }
369
370 pub(crate) fn get_callable(&self, key: &SymbolKey) -> Option<Rc<CallableMetadata>> {
372 self.symtable.get_callable(key)
373 }
374
375 pub(crate) fn put_local(&mut self, key: SymbolKey, proto: SymbolPrototype) -> Result<Register> {
377 put_var(key, proto, &mut self.locals, Register::local, RegisterScope::Local)
378 }
379
380 pub(crate) fn contains_local(&self, key: &SymbolKey) -> bool {
382 self.locals.get(key).is_some()
383 }
384
385 pub(crate) fn contains_global(&self, key: &SymbolKey) -> bool {
387 self.symtable.contains_global(key)
388 }
389
390 pub(crate) fn iter_globals(
392 &self,
393 ) -> impl Iterator<Item = (SymbolKey, SymbolPrototype, u8)> + '_ {
394 self.symtable.iter_globals().map(|(k, v, i)| (k.clone(), v, i))
395 }
396
397 pub(crate) fn iter_locals(
399 &self,
400 ) -> impl Iterator<Item = (SymbolKey, SymbolPrototype, u8)> + '_ {
401 self.locals.iter().map(|(k, v, i)| (k.clone(), *v, i))
402 }
403
404 pub(crate) fn fixup_local_type(&mut self, vref: &VarRef, new_etype: ExprType) -> Result<()> {
408 let key = SymbolKey::from(&vref.name);
409 match self.locals.get_mut(&key) {
411 Some((SymbolPrototype::Array(_), _)) | None => {
412 Err(Error::UndefinedSymbol(vref.clone(), RegisterScope::Local))
413 }
414
415 Some((SymbolPrototype::Scalar(etype), _reg)) => {
416 *etype = new_etype;
417 Ok(())
418 }
419 }
420 }
421}
422
/// Frozen view of a `LocalSymtable` used while temporaries are being allocated.
///
/// It only exposes lookups and `temp_scope`, so no locals can be added through it. Temporaries
/// are handed out by the `TempScope`s it creates.
pub(crate) struct TempSymtable<'temp, 'local> {
    /// The local scope this view was created from.
    symtable: &'temp mut LocalSymtable<'local>,

    /// Value of `next_temp` at creation time. All temporaries must have been released back to
    /// this value when the view is dropped (checked in debug builds).
    base_temp: u8,

    /// Index of the next free temporary; shared with every `TempScope` created from this view.
    next_temp: Rc<RefCell<u8>>,

    /// Highest temporary count reached; shared with the `TempScope`s and folded into the local
    /// scope's `count_temps` on drop.
    count_temps: Rc<RefCell<u8>>,
}
443
444impl<'temp, 'local> Drop for TempSymtable<'temp, 'local> {
445 fn drop(&mut self) {
446 debug_assert_eq!(self.base_temp, *self.next_temp.borrow(), "Unbalanced temp drops");
447 self.symtable.count_temps = max(self.symtable.count_temps, *self.count_temps.borrow());
448 }
449}
450
451impl<'temp, 'local> TempSymtable<'temp, 'local> {
452 fn new(symtable: &'temp mut LocalSymtable<'local>) -> Self {
454 let base_temp = symtable.active_temps.get();
455 Self {
456 symtable,
457 base_temp,
458 next_temp: Rc::from(RefCell::from(base_temp)),
459 count_temps: Rc::from(RefCell::from(base_temp)),
460 }
461 }
462
463 pub(crate) fn get_local_or_global(&self, vref: &VarRef) -> Result<(Register, SymbolPrototype)> {
465 self.symtable.get_local_or_global(vref)
466 }
467
468 pub(crate) fn get_callable(&self, key: &SymbolKey) -> Option<Rc<CallableMetadata>> {
470 self.symtable.get_callable(key)
471 }
472
473 pub(crate) fn temp_scope(&self) -> TempScope {
475 let nlocals = u8::try_from(self.symtable.locals.len())
476 .expect("Cannot have allocated more locals than u8");
477 TempScope {
478 base_temp: *self.next_temp.borrow(),
479 nlocals,
480 ntemps: 0,
481 next_temp: self.next_temp.clone(),
482 count_temps: self.count_temps.clone(),
483 }
484 }
485}
486
/// A nested allocation scope for temporary registers.
///
/// Temporaries allocated through it stay taken until the scope is dropped; on drop, they are all
/// returned to the shared `next_temp` counter.
pub(crate) struct TempScope {
    /// Temporary index that was next free when this scope was created.
    base_temp: u8,

    /// Number of locals; added to a temporary index to obtain its local register number.
    nlocals: u8,

    /// Number of temporaries allocated by this scope, released on drop.
    ntemps: u8,

    /// Index of the next free temporary, shared with the owning `TempSymtable` and siblings.
    next_temp: Rc<RefCell<u8>>,

    /// Highest temporary count reached, shared with the owning `TempSymtable`.
    count_temps: Rc<RefCell<u8>>,
}
506
impl Drop for TempScope {
    /// Releases every temporary allocated by this scope by rewinding the shared counter.
    fn drop(&mut self) {
        let mut next_temp = self.next_temp.borrow_mut();
        // The shared counter can never be below what this scope itself allocated.
        debug_assert!(*next_temp >= self.ntemps);
        *next_temp -= self.ntemps;
    }
}
514
515impl TempScope {
516 pub(crate) fn first(&mut self) -> Result<Register> {
518 let reg = u8::try_from(usize::from(self.nlocals) + usize::from(self.base_temp))
519 .map_err(|_| Error::OutOfRegisters(RegisterScope::Temp))?;
520 Register::local(reg).map_err(|_| Error::OutOfRegisters(RegisterScope::Temp))
521 }
522
523 pub(crate) fn alloc(&mut self) -> Result<Register> {
525 let temp;
526 let new_next_temp;
527 {
528 let mut next_temp = self.next_temp.borrow_mut();
529 temp = *next_temp;
530 self.ntemps += 1;
531 new_next_temp = match next_temp.checked_add(1) {
532 Some(reg) => reg,
533 None => return Err(Error::OutOfRegisters(RegisterScope::Temp)),
534 };
535 *next_temp = new_next_temp;
536 }
537
538 {
539 let mut count_temps = self.count_temps.borrow_mut();
540 *count_temps = max(*count_temps, new_next_temp);
541 }
542
543 match u8::try_from(usize::from(self.nlocals) + usize::from(temp)) {
544 Ok(reg) => {
545 Ok(Register::local(reg).map_err(|_| Error::OutOfRegisters(RegisterScope::Temp))?)
546 }
547 Err(_) => Err(Error::OutOfRegisters(RegisterScope::Temp)),
548 }
549 }
550}
551
// Unit tests for the global, local and temporary symbol tables.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::CallableMetadataBuilder;

    /// Keys differing only in ASCII case compare equal.
    #[test]
    fn test_symbol_key_case_insensitive() {
        assert_eq!(SymbolKey::from("foo"), SymbolKey::from("FOO"));
        assert_eq!(SymbolKey::from("Foo"), SymbolKey::from("fOo"));
    }

    /// Keys display in their normalized uppercase form.
    #[test]
    fn test_symbol_key_display() {
        assert_eq!("FOO", format!("{}", SymbolKey::from("foo")));
    }

    /// Globals get consecutive registers in definition order and can be looked up with or
    /// without a matching type annotation.
    #[test]
    fn test_global_put_and_get() -> Result<()> {
        let upcalls = HashMap::default();
        let mut global = GlobalSymtable::new(upcalls);

        let reg =
            global.put_global(SymbolKey::from("x"), SymbolPrototype::Scalar(ExprType::Integer))?;
        assert_eq!(Register::global(0).unwrap(), reg);

        let reg =
            global.put_global(SymbolKey::from("y"), SymbolPrototype::Scalar(ExprType::Text))?;
        assert_eq!(Register::global(1).unwrap(), reg);

        let (reg, proto) = global.get_global(&VarRef::new("x", None))?;
        assert_eq!(Register::global(0).unwrap(), reg);
        assert_eq!(SymbolPrototype::Scalar(ExprType::Integer), proto);

        let (reg, proto) = global.get_global(&VarRef::new("y", Some(ExprType::Text)))?;
        assert_eq!(Register::global(1).unwrap(), reg);
        assert_eq!(SymbolPrototype::Scalar(ExprType::Text), proto);

        Ok(())
    }

    /// Global lookups ignore the case of the referenced name.
    #[test]
    fn test_global_get_case_insensitive() -> Result<()> {
        let upcalls = HashMap::default();
        let mut global = GlobalSymtable::new(upcalls);
        global.put_global(SymbolKey::from("MyVar"), SymbolPrototype::Scalar(ExprType::Double))?;

        let (reg, proto) = global.get_global(&VarRef::new("myvar", None))?;
        assert_eq!(Register::global(0).unwrap(), reg);
        assert_eq!(SymbolPrototype::Scalar(ExprType::Double), proto);

        let (reg2, _) = global.get_global(&VarRef::new("MYVAR", None))?;
        assert_eq!(reg, reg2);
        Ok(())
    }

    /// A reference whose annotation does not match the global's type is rejected.
    #[test]
    fn test_global_get_incompatible_type() {
        let upcalls = HashMap::default();
        let mut global = GlobalSymtable::new(upcalls);
        global
            .put_global(SymbolKey::from("x"), SymbolPrototype::Scalar(ExprType::Integer))
            .unwrap();

        let err = global.get_global(&VarRef::new("x", Some(ExprType::Text))).unwrap_err();
        assert_eq!("Incompatible type annotation in x$ reference", err.to_string());
    }

    /// Looking up a missing global reports an undefined global symbol.
    #[test]
    fn test_global_get_undefined() {
        let upcalls = HashMap::default();
        let global = GlobalSymtable::new(upcalls);

        let err = global.get_global(&VarRef::new("x", None)).unwrap_err();
        assert_eq!("Undefined global symbol x", err.to_string());
    }

    /// Locals get consecutive local registers starting at zero.
    #[test]
    fn test_local_put_and_get() -> Result<()> {
        let upcalls = HashMap::default();
        let mut global = GlobalSymtable::new(upcalls);
        let mut local = global.enter_scope();

        let reg =
            local.put_local(SymbolKey::from("a"), SymbolPrototype::Scalar(ExprType::Boolean))?;
        assert_eq!(Register::local(0).unwrap(), reg);

        let reg =
            local.put_local(SymbolKey::from("b"), SymbolPrototype::Scalar(ExprType::Double))?;
        assert_eq!(Register::local(1).unwrap(), reg);

        let (reg, proto) = local.get_local_or_global(&VarRef::new("a", None))?;
        assert_eq!(Register::local(0).unwrap(), reg);
        assert_eq!(SymbolPrototype::Scalar(ExprType::Boolean), proto);

        Ok(())
    }

    /// A local with the same name as a global takes precedence in lookups.
    #[test]
    fn test_local_shadows_global() -> Result<()> {
        let upcalls = HashMap::default();
        let mut global = GlobalSymtable::new(upcalls);
        global.put_global(SymbolKey::from("x"), SymbolPrototype::Scalar(ExprType::Integer))?;

        let mut local = global.enter_scope();
        local.put_local(SymbolKey::from("x"), SymbolPrototype::Scalar(ExprType::Text))?;

        let (reg, proto) = local.get_local_or_global(&VarRef::new("x", None))?;
        assert_eq!(Register::local(0).unwrap(), reg);
        assert_eq!(SymbolPrototype::Scalar(ExprType::Text), proto);

        Ok(())
    }

    /// Names that are not local resolve to the global table.
    #[test]
    fn test_local_falls_through_to_global() -> Result<()> {
        let upcalls = HashMap::default();
        let mut global = GlobalSymtable::new(upcalls);
        global.put_global(SymbolKey::from("g"), SymbolPrototype::Scalar(ExprType::Integer))?;

        let local = global.enter_scope();
        let (reg, proto) = local.get_local_or_global(&VarRef::new("g", None))?;
        assert_eq!(Register::global(0).unwrap(), reg);
        assert_eq!(SymbolPrototype::Scalar(ExprType::Integer), proto);

        Ok(())
    }

    /// A name missing everywhere reports the global lookup failure, since the local miss falls
    /// through to the globals.
    #[test]
    fn test_local_get_undefined() {
        let upcalls = HashMap::default();
        let mut global = GlobalSymtable::new(upcalls);
        let local = global.enter_scope();

        let err = local.get_local_or_global(&VarRef::new("nope", None)).unwrap_err();
        assert_eq!("Undefined global symbol nope", err.to_string());
    }

    /// Globals defined through a local scope are visible from that scope.
    #[test]
    fn test_local_put_global_through_local() -> Result<()> {
        let upcalls = HashMap::default();
        let mut global = GlobalSymtable::new(upcalls);
        let mut local = global.enter_scope();

        let reg =
            local.put_global(SymbolKey::from("g"), SymbolPrototype::Scalar(ExprType::Integer))?;
        assert_eq!(Register::global(0).unwrap(), reg);

        let (reg, proto) = local.get_local_or_global(&VarRef::new("g", None))?;
        assert_eq!(Register::global(0).unwrap(), reg);
        assert_eq!(SymbolPrototype::Scalar(ExprType::Integer), proto);

        Ok(())
    }

    /// The type of a scalar local can be replaced after its definition.
    #[test]
    fn test_fixup_local_type() -> Result<()> {
        let upcalls = HashMap::default();
        let mut global = GlobalSymtable::new(upcalls);
        let mut local = global.enter_scope();

        local.put_local(SymbolKey::from("x"), SymbolPrototype::Scalar(ExprType::Integer))?;
        local.fixup_local_type(&VarRef::new("x", None), ExprType::Double)?;

        let (_, proto) = local.get_local_or_global(&VarRef::new("x", None))?;
        assert_eq!(SymbolPrototype::Scalar(ExprType::Double), proto);

        Ok(())
    }

    /// Retyping a missing local reports an undefined local symbol.
    #[test]
    fn test_fixup_local_type_undefined() {
        let upcalls = HashMap::default();
        let mut global = GlobalSymtable::new(upcalls);
        let mut local = global.enter_scope();

        let err =
            local.fixup_local_type(&VarRef::new("nope", None), ExprType::Integer).unwrap_err();
        assert_eq!("Undefined local symbol nope", err.to_string());
    }

    /// A declared user callable can be found case-insensitively.
    #[test]
    fn test_define_and_get_user_callable() -> Result<()> {
        let upcalls = HashMap::default();
        let mut global = GlobalSymtable::new(upcalls);

        let md = CallableMetadataBuilder::new("MY_FUNC")
            .with_return_type(ExprType::Integer)
            .test_build();
        global.declare_user_callable(&VarRef::new("my_func", None), md)?;

        let found = global.get_callable(&SymbolKey::from("my_func"));
        assert!(found.is_some());
        assert_eq!("MY_FUNC", found.unwrap().name());

        Ok(())
    }

    /// Redeclaring a user callable with equal metadata is accepted.
    #[test]
    fn test_define_user_callable_already_defined_but_is_compatible() -> Result<()> {
        let upcalls = HashMap::default();
        let mut global = GlobalSymtable::new(upcalls);

        let md = CallableMetadataBuilder::new("DUP").test_build();
        global.declare_user_callable(&VarRef::new("dup", None), md)?;

        let md2 = CallableMetadataBuilder::new("DUP").test_build();
        global.declare_user_callable(&VarRef::new("dup", None), md2)?;

        Ok(())
    }

    /// Redeclaring a user callable with different metadata is rejected.
    #[test]
    fn test_define_user_callable_already_defined_but_is_incompatible() -> Result<()> {
        let upcalls = HashMap::default();
        let mut global = GlobalSymtable::new(upcalls);

        let md = CallableMetadataBuilder::new("DUP").test_build();
        global.declare_user_callable(&VarRef::new("dup", None), md)?;

        let md2 =
            CallableMetadataBuilder::new("DUP").with_return_type(ExprType::Integer).test_build();
        let err = global.declare_user_callable(&VarRef::new("dup", None), md2).unwrap_err();
        assert_eq!("Cannot redefine dup", err.to_string());

        Ok(())
    }

    /// User callables declared through a local scope are visible from it.
    #[test]
    fn test_define_user_callable_via_local() -> Result<()> {
        let upcalls = HashMap::default();
        let mut global = GlobalSymtable::new(upcalls);
        let mut local = global.enter_scope();

        let md = CallableMetadataBuilder::new("SUB1").test_build();
        local.declare_user_callable(&VarRef::new("sub1", None), md)?;

        let found = local.get_callable(&SymbolKey::from("sub1"));
        assert!(found.is_some());

        Ok(())
    }

    /// Upcalls passed at construction time are found by `get_callable`.
    #[test]
    fn test_get_callable_upcall() {
        let key = SymbolKey::from("BUILTIN");
        let md = CallableMetadataBuilder::new("BUILTIN").test_build();
        let mut upcalls_map = HashMap::new();
        upcalls_map.insert(key, md);

        let global = GlobalSymtable::new(upcalls_map);
        let found = global.get_callable(&SymbolKey::from("builtin"));
        assert!(found.is_some());
        assert_eq!("BUILTIN", found.unwrap().name());
    }

    /// A user callable with the same name as an upcall takes precedence.
    #[test]
    fn test_user_callable_shadows_upcall() {
        let key = SymbolKey::from("SHARED");
        let builtin_md =
            CallableMetadataBuilder::new("SHARED").with_return_type(ExprType::Boolean).test_build();
        let mut upcalls_map = HashMap::new();
        upcalls_map.insert(key, builtin_md);

        let mut global = GlobalSymtable::new(upcalls_map);
        let user_md =
            CallableMetadataBuilder::new("SHARED").with_return_type(ExprType::Integer).test_build();
        global.declare_user_callable(&VarRef::new("shared", None), user_md).unwrap();

        let found = global.get_callable(&SymbolKey::from("shared")).unwrap();
        assert_eq!(Some(ExprType::Integer), found.return_type());
    }

    /// Unknown callables are reported as `None`.
    #[test]
    fn test_get_callable_not_found() {
        let upcalls = HashMap::default();
        let global = GlobalSymtable::new(upcalls);
        assert!(global.get_callable(&SymbolKey::from("nope")).is_none());
    }

    /// The first temporary register comes right after the locals.
    #[test]
    fn test_temp_scope_first() -> Result<()> {
        let upcalls = HashMap::default();
        let mut global = GlobalSymtable::new(upcalls);
        let mut local = global.enter_scope();
        local.put_local(SymbolKey::from("a"), SymbolPrototype::Scalar(ExprType::Integer))?;
        local.put_local(SymbolKey::from("b"), SymbolPrototype::Scalar(ExprType::Integer))?;
        {
            let temp = local.frozen();
            let mut scope = temp.temp_scope();
            assert_eq!(Register::local(2).unwrap(), scope.first()?);
        }
        Ok(())
    }

    /// Without locals, the first temporary is local register zero.
    #[test]
    fn test_temp_scope_first_no_locals() -> Result<()> {
        let upcalls = HashMap::default();
        let mut global = GlobalSymtable::new(upcalls);
        let mut local = global.enter_scope();
        {
            let temp = local.frozen();
            let mut scope = temp.temp_scope();
            assert_eq!(Register::local(0).unwrap(), scope.first()?);
        }
        Ok(())
    }

    /// A nested temp scope starts after the temporaries allocated by its enclosing scope.
    #[test]
    fn test_temp_scope_first_with_outer_allocation() -> Result<()> {
        let upcalls = HashMap::default();
        let mut global = GlobalSymtable::new(upcalls);
        let mut local = global.enter_scope();
        local.put_local(SymbolKey::from("a"), SymbolPrototype::Scalar(ExprType::Integer))?;
        {
            let temp = local.frozen();
            let mut outer = temp.temp_scope();
            assert_eq!(Register::local(1).unwrap(), outer.alloc()?);

            let mut inner = temp.temp_scope();
            assert_eq!(Register::local(2).unwrap(), inner.first()?);
        }
        Ok(())
    }

    /// Temporaries of nested scopes are released on drop and reused by later scopes.
    #[test]
    fn test_temp_scope() -> Result<()> {
        let upcalls = HashMap::default();
        let mut global = GlobalSymtable::new(upcalls);
        let mut local = global.enter_scope();
        assert_eq!(
            Register::local(0).unwrap(),
            local.put_local(SymbolKey::from("foo"), SymbolPrototype::Scalar(ExprType::Integer))?
        );
        {
            let temp = local.frozen();
            {
                let mut scope = temp.temp_scope();
                assert_eq!(Register::local(1).unwrap(), scope.alloc()?);
                {
                    let mut scope = temp.temp_scope();
                    assert_eq!(Register::local(2).unwrap(), scope.alloc()?);
                    assert_eq!(Register::local(3).unwrap(), scope.alloc()?);
                    assert_eq!(Register::local(4).unwrap(), scope.alloc()?);
                }
                {
                    let mut scope = temp.temp_scope();
                    assert_eq!(Register::local(2).unwrap(), scope.alloc()?);
                    assert_eq!(Register::local(3).unwrap(), scope.alloc()?);
                }
                assert_eq!(Register::local(2).unwrap(), scope.alloc()?);
            }
        }
        {
            let temp = local.frozen();
            {
                let mut scope = temp.temp_scope();
                assert_eq!(Register::local(1).unwrap(), scope.alloc()?);
            }
        }
        Ok(())
    }

    /// The reserved temporary is the first register after the locals.
    #[test]
    fn test_with_reserved_temp_register_index() -> Result<()> {
        let upcalls = HashMap::default();
        let mut global = GlobalSymtable::new(upcalls);
        let mut local = global.enter_scope();
        local.put_local(SymbolKey::from("a"), SymbolPrototype::Scalar(ExprType::Integer))?;
        local.put_local(SymbolKey::from("b"), SymbolPrototype::Scalar(ExprType::Integer))?;

        local.with_reserved_temp(
            |e| e,
            |reg, _| {
                assert_eq!(Register::local(2).unwrap(), reg);
                Ok(())
            },
        )?;

        Ok(())
    }

    /// Temporaries allocated inside a reservation come after the reserved register.
    #[test]
    fn test_with_reserved_temp_shifts_temp_scope_base() -> Result<()> {
        let upcalls = HashMap::default();
        let mut global = GlobalSymtable::new(upcalls);
        let mut local = global.enter_scope();
        local.put_local(SymbolKey::from("a"), SymbolPrototype::Scalar(ExprType::Integer))?;

        local.with_reserved_temp(
            |e| e,
            |reserved, temp| {
                assert_eq!(Register::local(1).unwrap(), reserved);
                let mut scope = temp.temp_scope();
                assert_eq!(Register::local(2).unwrap(), scope.alloc()?);
                Ok(())
            },
        )?;

        Ok(())
    }

    /// A reservation is released even when the callback fails, so the next reservation reuses
    /// the same register.
    #[test]
    fn test_with_reserved_temp_released_after_error() {
        let upcalls = HashMap::default();
        let mut global = GlobalSymtable::new(upcalls);
        let mut local = global.enter_scope();

        let err = local
            .with_reserved_temp(
                |e| e,
                |_, _| Err::<(), Error>(Error::OutOfRegisters(RegisterScope::Temp)),
            )
            .unwrap_err();
        assert_eq!("Out of temp registers", err.to_string());

        local
            .with_reserved_temp(
                |e| e,
                |reg, _| {
                    assert_eq!(Register::local(0).unwrap(), reg);
                    Ok(())
                },
            )
            .unwrap();
    }

    /// Variable lookups through a frozen view reach both locals and globals.
    #[test]
    fn test_temp_scope_lookup_vars() -> Result<()> {
        let upcalls = HashMap::default();
        let mut global = GlobalSymtable::new(upcalls);
        global.put_global(SymbolKey::from("g"), SymbolPrototype::Scalar(ExprType::Integer))?;
        let mut local = global.enter_scope();
        local.put_local(SymbolKey::from("l"), SymbolPrototype::Scalar(ExprType::Text))?;

        {
            let temp = local.frozen();

            let (reg, proto) = temp.get_local_or_global(&VarRef::new("l", None))?;
            assert_eq!(Register::local(0).unwrap(), reg);
            assert_eq!(SymbolPrototype::Scalar(ExprType::Text), proto);

            let (reg, proto) = temp.get_local_or_global(&VarRef::new("g", None))?;
            assert_eq!(Register::global(0).unwrap(), reg);
            assert_eq!(SymbolPrototype::Scalar(ExprType::Integer), proto);
        }

        Ok(())
    }

    /// Callable lookups through a frozen view reach the global table.
    #[test]
    fn test_temp_scope_lookup_callable() -> Result<()> {
        let upcalls = HashMap::default();
        let mut global = GlobalSymtable::new(upcalls);
        let md = CallableMetadataBuilder::new("FOO").test_build();
        global.declare_user_callable(&VarRef::new("foo", None), md)?;

        let mut local = global.enter_scope();
        {
            let temp = local.frozen();
            assert!(temp.get_callable(&SymbolKey::from("foo")).is_some());
            assert!(temp.get_callable(&SymbolKey::from("nope")).is_none());
        }

        Ok(())
    }

    /// Locals of one scope are not visible from a later scope, and local register numbering
    /// restarts at zero.
    #[test]
    fn test_multiple_scopes_independent_locals() -> Result<()> {
        let upcalls = HashMap::default();
        let mut global = GlobalSymtable::new(upcalls);

        {
            let mut local = global.enter_scope();
            local.put_local(SymbolKey::from("x"), SymbolPrototype::Scalar(ExprType::Integer))?;
        }

        {
            let mut local = global.enter_scope();
            let err = local.get_local_or_global(&VarRef::new("x", None)).unwrap_err();
            assert_eq!("Undefined global symbol x", err.to_string());

            let reg =
                local.put_local(SymbolKey::from("y"), SymbolPrototype::Scalar(ExprType::Double))?;
            assert_eq!(Register::local(0).unwrap(), reg);
        }

        Ok(())
    }

    /// Global arrays keep their element type and dimension count.
    #[test]
    fn test_global_put_and_get_array() -> Result<()> {
        let upcalls = HashMap::default();
        let mut global = GlobalSymtable::new(upcalls);

        let reg = global.put_global(
            SymbolKey::from("arr"),
            SymbolPrototype::Array(ArrayInfo { subtype: ExprType::Integer, ndims: 2 }),
        )?;
        assert_eq!(Register::global(0).unwrap(), reg);

        let (got_reg, proto) = global.get_global(&VarRef::new("arr", None)).unwrap();
        assert_eq!(Register::global(0).unwrap(), got_reg);
        let SymbolPrototype::Array(info) = proto else { panic!("Expected Array prototype") };
        assert_eq!(ExprType::Integer, info.subtype);
        assert_eq!(2, info.ndims);

        Ok(())
    }

    /// Local arrays keep their element type and dimension count.
    #[test]
    fn test_local_put_and_get_array() -> Result<()> {
        let upcalls = HashMap::default();
        let mut global = GlobalSymtable::new(upcalls);
        let mut local = global.enter_scope();

        let reg = local.put_local(
            SymbolKey::from("arr"),
            SymbolPrototype::Array(ArrayInfo { subtype: ExprType::Double, ndims: 1 }),
        )?;
        assert_eq!(Register::local(0).unwrap(), reg);

        let (got_reg, proto) = local.get_local_or_global(&VarRef::new("arr", None)).unwrap();
        assert_eq!(Register::local(0).unwrap(), got_reg);
        let SymbolPrototype::Array(info) = proto else { panic!("Expected Array prototype") };
        assert_eq!(ExprType::Double, info.subtype);
        assert_eq!(1, info.ndims);

        Ok(())
    }
}
1085}