use std::{cell::RefCell, collections::HashSet, rc::Rc};
use crate::{InferFailReason, LuaTypeDeclId};
/// Shared, reference-counted handle to an [`InferGuard`].
pub type InferGuardRef = Rc<InferGuard>;

/// Recursion guard for type inference, organized as a chain of scopes.
///
/// Each guard owns the set of type declarations it has recorded itself and may point at the
/// guard it was forked from; lookups consult the whole chain, while insertions only touch the
/// guard's own set.
#[derive(Debug, Clone)]
pub struct InferGuard {
    /// Types recorded directly in this guard; stays `None` until the first type is recorded.
    current: RefCell<Option<HashSet<LuaTypeDeclId>>>,
    /// Guard this one was forked from, or `None` for a root guard.
    parent: Option<InferGuardRef>,
}
impl InferGuard {
pub fn new() -> Rc<Self> {
Rc::new(Self {
current: RefCell::new(None),
parent: None,
})
}
pub fn fork(self: &Rc<Self>) -> Rc<Self> {
Rc::new(Self {
current: RefCell::new(None), parent: Some(Rc::clone(self)),
})
}
pub fn check(&self, type_id: &LuaTypeDeclId) -> Result<(), InferFailReason> {
if self.contains_in_parents(type_id) {
return Err(InferFailReason::RecursiveInfer);
}
let mut current_opt = self.current.borrow_mut();
let current = current_opt.get_or_insert_with(HashSet::default);
if current.contains(type_id) {
return Err(InferFailReason::RecursiveInfer);
}
current.insert(type_id.clone());
Ok(())
}
fn contains_in_parents(&self, type_id: &LuaTypeDeclId) -> bool {
let mut current_parent = self.parent.as_ref();
while let Some(parent) = current_parent {
if let Some(ref set) = *parent.current.borrow() {
if set.contains(type_id) {
return true;
}
}
current_parent = parent.parent.as_ref();
}
false
}
pub fn contains(&self, type_id: &LuaTypeDeclId) -> bool {
if let Some(ref set) = *self.current.borrow() {
if set.contains(type_id) {
return true;
}
}
self.contains_in_parents(type_id)
}
pub fn current_depth(&self) -> usize {
self.current.borrow().as_ref().map_or(0, |set| set.len())
}
pub fn total_depth(&self) -> usize {
let mut depth = self.current_depth();
let mut current_parent = self.parent.as_ref();
while let Some(parent) = current_parent {
depth += parent.current_depth();
current_parent = parent.parent.as_ref();
}
depth
}
pub fn level(&self) -> usize {
let mut level = 0;
let mut current_parent = self.parent.as_ref();
while let Some(parent) = current_parent {
level += 1;
current_parent = parent.parent.as_ref();
}
level
}
#[cfg(test)]
pub fn has_allocated(&self) -> bool {
self.current.borrow().is_some()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh guard and its forks stay unallocated until `check` records a type.
    #[test]
    fn test_lazy_allocation() {
        let root = InferGuard::new();
        assert!(!root.has_allocated(), "New guard should not allocate");
        let child = root.fork();
        assert!(!child.has_allocated(), "Fork should not allocate memory");
        let type_b = LuaTypeDeclId::global("TestTypeB");
        child.check(&type_b).unwrap();
        assert!(
            child.has_allocated(),
            "Check should trigger lazy allocation"
        );
        assert!(!root.has_allocated(), "Root should not be affected");
    }

    /// Forks see their ancestors' types without allocating anything themselves.
    #[test]
    fn test_fork_without_write() {
        let root = InferGuard::new();
        let type_a = LuaTypeDeclId::global("TestTypeA");
        root.check(&type_a).unwrap();
        let child1 = root.fork();
        let child2 = root.fork();
        let grandchild = child1.fork();
        assert!(!child1.has_allocated());
        assert!(!child2.has_allocated());
        assert!(!grandchild.has_allocated());
        assert!(child1.contains(&type_a));
        assert!(child2.contains(&type_a));
        assert!(grandchild.contains(&type_a));
    }

    /// Recording the same type twice in one guard is rejected.
    #[test]
    fn test_recursive_detection() {
        let root = InferGuard::new();
        let type_a = LuaTypeDeclId::global("TestTypeA");
        assert!(root.check(&type_a).is_ok());
        assert!(root.check(&type_a).is_err());
    }

    /// Types recorded anywhere up the parent chain are rejected in descendants.
    #[test]
    fn test_parent_chain_detection() {
        let root = InferGuard::new();
        let type_a = LuaTypeDeclId::global("TestTypeA");
        let type_b = LuaTypeDeclId::global("TestTypeB");
        root.check(&type_a).unwrap();
        let child = root.fork();
        assert!(child.check(&type_a).is_err());
        assert!(child.check(&type_b).is_ok());
        let grandchild = child.fork();
        assert!(grandchild.check(&type_a).is_err());
        assert!(grandchild.check(&type_b).is_err());
    }

    /// A check rejected because of an ancestor must not allocate the child's set.
    #[test]
    fn test_failed_check_does_not_allocate() {
        let root = InferGuard::new();
        let type_a = LuaTypeDeclId::global("TestTypeA");
        root.check(&type_a).unwrap();
        let child = root.fork();
        assert!(child.check(&type_a).is_err());
        assert!(!child.has_allocated());
    }

    /// Types recorded in a fork are invisible to its parent and to sibling forks.
    #[test]
    fn test_sibling_isolation() {
        let root = InferGuard::new();
        let type_a = LuaTypeDeclId::global("TestTypeA");
        let left = root.fork();
        let right = root.fork();
        left.check(&type_a).unwrap();
        assert!(!root.contains(&type_a));
        assert!(!right.contains(&type_a));
        assert!(right.check(&type_a).is_ok());
    }

    /// `level`, `current_depth` and `total_depth` reflect the chain shape and recorded types.
    #[test]
    fn test_depth_and_level() {
        let root = InferGuard::new();
        let type_a = LuaTypeDeclId::global("TestTypeA");
        let type_b = LuaTypeDeclId::global("TestTypeB");
        assert_eq!(root.level(), 0);
        assert_eq!(root.current_depth(), 0);
        assert_eq!(root.total_depth(), 0);

        root.check(&type_a).unwrap();
        let child = root.fork();
        child.check(&type_b).unwrap();
        let grandchild = child.fork();

        assert_eq!(root.current_depth(), 1);
        assert_eq!(root.total_depth(), 1);
        assert_eq!(child.level(), 1);
        assert_eq!(child.current_depth(), 1);
        assert_eq!(child.total_depth(), 2);
        assert_eq!(grandchild.level(), 2);
        assert_eq!(grandchild.current_depth(), 0);
        assert_eq!(grandchild.total_depth(), 2);
    }

    /// A long chain of forks keeps only the root allocated, and every fork sees the root's type.
    #[test]
    fn test_memory_efficiency() {
        let root = InferGuard::new();
        let type_a = LuaTypeDeclId::global("TestTypeA");
        root.check(&type_a).unwrap();
        let mut guards = vec![root];
        for _ in 0..10 {
            let child = guards.last().unwrap().fork();
            guards.push(child);
        }
        let (root, forks) = guards.split_first().unwrap();
        assert!(root.has_allocated(), "Root should be allocated");
        // Every fork, including the deepest one, must stay unallocated.
        for (i, guard) in forks.iter().enumerate() {
            assert!(!guard.has_allocated(), "Fork {} should not allocate", i + 1);
            assert!(
                guard.contains(&type_a),
                "Fork {} should see the root's type",
                i + 1
            );
        }
    }
}