use indexmap::IndexMap;
use std::cell::Cell;
use std::fmt;
use std::hash::Hash;
use std::sync::Arc;
use slotmap::{SlotMap, new_key_type};
use super::Bytecode;
use super::LuaType;
use super::RuntimeCaches;
use super::Table;
use super::Val;
/// A variable captured by a closure.
///
/// During marking only the `Closed` variant is traversed (see
/// `GcHeap::mark_children`); `Open` upvalues are skipped there.
#[derive(Clone, Debug)]
pub(crate) enum Upvalue {
// NOTE(review): presumably an index into the VM value stack where the
// captured variable still lives — confirm against the interpreter.
Open(usize),
// The captured value itself, owned by the upvalue.
Closed(Val),
}
/// Compact, copyable handle naming one slot of an `UpvaluePool`.
///
/// The handle is just the slot number stored as a `u32`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) struct UpvalueRef(u32);

impl UpvalueRef {
    /// Wraps the raw slot number `idx` in a handle.
    fn new(idx: u32) -> Self {
        UpvalueRef(idx)
    }

    /// Returns the slot number widened to `usize`, ready for indexing.
    fn index(self) -> usize {
        let UpvalueRef(raw) = self;
        raw as usize
    }
}
pub(crate) struct UpvaluePool {
slots: Vec<Upvalue>,
}
impl Default for UpvaluePool {
fn default() -> Self {
Self::new()
}
}
impl UpvaluePool {
pub(super) fn new() -> Self {
Self {
slots: Vec::with_capacity(64),
}
}
pub(super) fn alloc(&mut self, upvalue: Upvalue) -> UpvalueRef {
let idx = self.slots.len() as u32;
self.slots.push(upvalue);
UpvalueRef::new(idx)
}
#[inline]
pub(super) fn get(&self, uv_ref: UpvalueRef) -> &Upvalue {
&self.slots[uv_ref.index()]
}
#[inline]
pub(super) fn get_mut(&mut self, uv_ref: UpvalueRef) -> &mut Upvalue {
&mut self.slots[uv_ref.index()]
}
}
/// A Lua function value: compiled code plus the upvalues it captured.
///
/// Cloning copies the `upvalues` vector and bumps the two `Arc` reference
/// counts; the bytecode and caches themselves are shared, not duplicated.
#[derive(Clone, Debug)]
pub(super) struct Closure {
// Compiled function body, shared between clones of this closure.
pub(super) bytecode: Arc<Bytecode>,
// Runtime caches built from `bytecode` when the closure is allocated
// (see `GcHeap::alloc_lua_fn`).
pub(super) caches: Arc<RuntimeCaches>,
// Handles into the `UpvaluePool` for each captured variable.
pub(super) upvalues: Vec<UpvalueRef>,
}
/// Payload of a garbage-collected heap object.
pub(super) enum RawObject {
    /// A Lua closure, boxed.
    LuaFn(Box<Closure>),
    /// A Lua table.
    Table(Table),
}

impl RawObject {
    /// Reports the Lua-level type tag that corresponds to this payload.
    #[must_use]
    pub(super) const fn typ(&self) -> LuaType {
        match *self {
            Self::LuaFn(..) => LuaType::Function,
            Self::Table(..) => LuaType::Table,
        }
    }
}
/// Mark state of a heap entry (object or interned string).
///
/// The `collect` methods in this file drop every `Unmarked` entry and reset
/// each `Reachable` entry back to `Unmarked`.
#[derive(Clone, Copy, PartialEq, Eq)]
pub(crate) enum Color {
// Not (yet) reached by marking; freed by the next sweep.
Unmarked,
// Reached by marking; survives the next sweep.
Reachable,
}
/// A heap object together with its garbage-collector mark.
pub(super) struct WrappedObject {
// The actual closure or table.
pub(super) raw: RawObject,
// Mark state; a `Cell` so marking can flip it through a shared `&GcHeap`.
pub(super) color: Cell<Color>,
}
new_key_type! {
/// Slot-map key identifying an entry of the object heap (`GcHeap::objects`).
pub struct ObjectKey;
}
new_key_type! {
/// Slot-map key identifying an interned string (`StringPool::strings`).
pub struct StringKey;
}
/// Copyable handle to an object stored in a [`GcHeap`].
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub(crate) struct ObjectPtr(pub(crate) ObjectKey);

impl ObjectPtr {
    /// Looks the object up in `heap` and returns its Lua type tag.
    ///
    /// Panics (through [`GcHeap::get`]) when the handle no longer refers to a
    /// live object.
    pub(super) fn typ(self, heap: &GcHeap) -> LuaType {
        let wrapped = heap.get(self);
        wrapped.raw.typ()
    }
}

impl fmt::Display for ObjectPtr {
    /// Renders the handle as `object: <key debug repr>`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self(key) = self;
        f.write_fmt(format_args!("object: {:?}", key))
    }
}
/// Mark-and-sweep heap holding Lua objects and the interned string pool.
pub(crate) struct GcHeap {
// All live (and not-yet-swept) closures and tables.
objects: SlotMap<ObjectKey, WrappedObject>,
// Object count at which `is_full` starts reporting true; recomputed at the
// end of every `collect`.
threshold: usize,
// Interned strings; swept together with the objects in `collect`.
strings: StringPool,
}
impl GcHeap {
    /// Builds an empty heap whose `is_full` check trips at `threshold` objects.
    pub(super) fn with_threshold(threshold: usize) -> Self {
        let objects = SlotMap::with_key();
        let strings = StringPool::new();
        Self {
            objects,
            threshold,
            strings,
        }
    }

    /// Resolves `ptr` to the object it names.
    ///
    /// # Panics
    ///
    /// Panics if the object has already been freed.
    #[inline]
    pub(super) fn get(&self, ptr: ObjectPtr) -> &WrappedObject {
        let slot = self.objects.get(ptr.0);
        slot.expect("Invalid ObjectPtr: object was freed (use-after-free detected)")
    }

    /// Mutable counterpart of [`GcHeap::get`].
    ///
    /// # Panics
    ///
    /// Panics if the object has already been freed.
    #[inline]
    pub(super) fn get_mut(&mut self, ptr: ObjectPtr) -> &mut WrappedObject {
        let slot = self.objects.get_mut(ptr.0);
        slot.expect("Invalid ObjectPtr: object was freed (use-after-free detected)")
    }

    /// Returns a clone of the closure behind `ptr`, or `None` when `ptr` names
    /// something other than a Lua function.
    pub(super) fn as_lua_function(&self, ptr: ObjectPtr) -> Option<Closure> {
        if let RawObject::LuaFn(boxed) = &self.get(ptr).raw {
            Some(Closure::clone(boxed))
        } else {
            None
        }
    }

    /// Returns the table behind `ptr` for mutation, or `None` when `ptr` does
    /// not name a table.
    pub(super) fn as_table(&mut self, ptr: ObjectPtr) -> Option<&mut Table> {
        if let RawObject::Table(table) = &mut self.get_mut(ptr).raw {
            Some(table)
        } else {
            None
        }
    }

    /// Read-only counterpart of [`GcHeap::as_table`].
    pub(super) fn as_table_ref(&self, ptr: ObjectPtr) -> Option<&Table> {
        if let RawObject::Table(table) = &self.get(ptr).raw {
            Some(table)
        } else {
            None
        }
    }

    /// Returns the bytes of the interned string behind `ptr`.
    pub(super) fn get_string(&self, ptr: StringPtr) -> &[u8] {
        let pool = &self.strings;
        pool.get(ptr)
    }

    /// Wraps `raw` as an unmarked heap object and returns its handle.
    fn insert_unmarked(&mut self, raw: RawObject) -> ObjectPtr {
        let key = self.objects.insert(WrappedObject {
            raw,
            color: Cell::new(Color::Unmarked),
        });
        ObjectPtr(key)
    }

    /// Allocates a closure over `bytecode` capturing `upvalues`.
    ///
    /// A fresh `RuntimeCaches` is built from the bytecode for every closure.
    #[hotpath::measure]
    pub(super) fn alloc_lua_fn(
        &mut self,
        bytecode: Arc<Bytecode>,
        upvalues: Vec<UpvalueRef>,
    ) -> ObjectPtr {
        // The caches borrow the bytecode, so build them before moving it.
        let caches = Arc::new(RuntimeCaches::new(&bytecode));
        let boxed = Box::new(Closure {
            bytecode,
            caches,
            upvalues,
        });
        self.insert_unmarked(RawObject::LuaFn(boxed))
    }

    /// Allocates a new default-constructed table.
    #[hotpath::measure]
    pub(super) fn alloc_table(&mut self) -> ObjectPtr {
        self.insert_unmarked(RawObject::Table(Table::default()))
    }

    /// Interns `bytes`, reusing the existing entry when an equal string is
    /// already in the pool.
    #[hotpath::measure]
    pub(super) fn alloc_string(&mut self, bytes: &[u8]) -> StringPtr {
        let digest = StringPool::hash_string(bytes);
        match self.strings.find_by_hash(bytes, digest) {
            Some(existing) => existing,
            None => self.strings.insert_with_hash(Box::from(bytes), digest),
        }
    }

    /// True once the number of objects has reached the current threshold.
    #[must_use]
    pub(super) fn is_full(&self) -> bool {
        self.threshold <= self.objects.len()
    }

    /// Marks the object behind `ptr` as reachable and, if it was not marked
    /// before, everything it references.
    ///
    /// Handles that no longer resolve to an object are ignored.
    #[hotpath::measure]
    pub(super) fn mark(&self, ptr: ObjectPtr, upvalue_pool: &UpvaluePool) {
        let Some(obj) = self.objects.get(ptr.0) else {
            return;
        };
        // Already-marked objects are skipped, which also terminates cycles.
        if obj.color.get() != Color::Unmarked {
            return;
        }
        obj.color.set(Color::Reachable);
        self.mark_children(obj, upvalue_pool);
    }

    /// Marks the interned string behind `ptr` as reachable.
    pub(super) fn mark_string(&self, ptr: StringPtr) {
        let pool = &self.strings;
        pool.mark(ptr);
    }

    /// Marks the values directly referenced by `obj`.
    #[hotpath::measure]
    fn mark_children(&self, obj: &WrappedObject, upvalue_pool: &UpvaluePool) {
        match &obj.raw {
            RawObject::LuaFn(closure) => {
                for handle in closure.upvalues.iter().copied() {
                    // Only closed upvalues own a value to trace from here.
                    match upvalue_pool.get(handle) {
                        Upvalue::Closed(val) => val.mark_reachable(self, upvalue_pool),
                        Upvalue::Open(_) => {}
                    }
                }
            }
            RawObject::Table(tbl) => {
                tbl.mark_values(self, upvalue_pool);
            }
        }
    }

    /// Sweeps the heap: frees every unmarked object and string, clears the
    /// marks on the survivors and recomputes the threshold as twice the
    /// surviving object count (never below 20).
    #[hotpath::measure(label = "object::heap_collect")]
    pub(super) fn collect(&mut self) {
        #[cfg(feature = "debug_gc")]
        {
            println!("Running garbage collector");
            println!("Initial size: {}", self.objects.len());
        }

        self.objects.retain(|_, obj| {
            let survives = obj.color.get() == Color::Reachable;
            if survives {
                // Reset the mark so the next cycle starts from a clean slate.
                obj.color.set(Color::Unmarked);
            }
            survives
        });
        self.strings.collect();

        let live = self.objects.len();
        self.threshold = usize::max(live * 2, 20);

        #[cfg(feature = "debug_gc")]
        println!("Final size: {}", self.objects.len());
    }

    /// Number of objects currently stored in the heap.
    pub(super) fn object_count(&self) -> usize {
        self.objects.len()
    }

    /// Number of strings currently interned.
    pub(super) fn string_count(&self) -> usize {
        self.strings.len()
    }

    /// Current object-count threshold used by [`GcHeap::is_full`].
    pub(super) fn threshold(&self) -> usize {
        self.threshold
    }

    /// Overrides the threshold used by [`GcHeap::is_full`].
    pub(super) fn set_threshold(&mut self, threshold: usize) {
        self.threshold = threshold;
    }
}
/// Something that can flag the heap entries it references as reachable
/// during the mark phase.
pub(super) trait Markable {
// Marks everything reachable from `self` in `heap`; `upvalue_pool` is
// passed through so closures can resolve their upvalue handles.
fn mark_reachable(&self, heap: &GcHeap, upvalue_pool: &UpvaluePool);
}
impl Markable for Val {
    /// Marks the heap object or interned string this value points at; every
    /// other kind of value is left alone.
    fn mark_reachable(&self, heap: &GcHeap, upvalue_pool: &UpvaluePool) {
        if let Val::Obj(object) = self {
            heap.mark(*object, upvalue_pool);
        } else if let Val::Str(string) = self {
            heap.mark_string(*string);
        }
    }
}
impl<T: Markable> Markable for [T] {
    /// Marks every element of the slice, front to back.
    fn mark_reachable(&self, heap: &GcHeap, upvalue_pool: &UpvaluePool) {
        self.iter()
            .for_each(|item| item.mark_reachable(heap, upvalue_pool));
    }
}
impl<K, V: Markable> Markable for IndexMap<K, V> {
    /// Marks every value stored in the map, in iteration order.
    ///
    /// NOTE(review): keys are never marked here. If a caller uses keys that
    /// themselves reference heap entries, those must be marked elsewhere —
    /// confirm against the users of this impl.
    fn mark_reachable(&self, heap: &GcHeap, upvalue_pool: &UpvaluePool) {
        self.values()
            .for_each(|item| item.mark_reachable(heap, upvalue_pool));
    }
}
/// One interned string together with its bookkeeping.
struct StringEntry {
// The string's bytes.
data: Box<[u8]>,
// Hash of `data` as supplied at insertion; `StringPool::collect` uses it
// to locate the entry's bucket in `hash_index`.
hash: u64,
// Mark state; a `Cell` so marking works through a shared reference.
color: Cell<Color>,
}
/// Copyable handle to a string interned in a [`StringPool`].
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub(crate) struct StringPtr(StringKey);

impl fmt::Display for StringPtr {
    /// Renders the handle as `string: <key debug repr>`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self(key) = self;
        f.write_fmt(format_args!("string: {:?}", key))
    }
}
/// Interning pool for byte strings, swept together with the object heap.
pub(crate) struct StringPool {
// Storage for the interned strings, addressed by `StringKey`.
strings: SlotMap<StringKey, StringEntry>,
// Maps a string hash to the keys of all entries inserted with that hash;
// lookups compare the actual bytes to resolve collisions.
hash_index: IndexMap<u64, Vec<StringKey>>,
}
impl StringPool {
fn new() -> Self {
Self {
strings: SlotMap::with_key(),
hash_index: IndexMap::new(),
}
}
pub(super) fn len(&self) -> usize {
self.strings.len()
}
pub(super) fn hash_string(bytes: &[u8]) -> u64 {
use std::hash::Hasher;
let mut hasher = std::collections::hash_map::DefaultHasher::new();
bytes.hash(&mut hasher);
hasher.finish()
}
pub(super) fn get(&self, ptr: StringPtr) -> &[u8] {
&self
.strings
.get(ptr.0)
.expect("Invalid StringPtr: string was freed (use-after-free detected)")
.data
}
#[hotpath::measure]
pub(super) fn find_by_hash(&self, bytes: &[u8], hash: u64) -> Option<StringPtr> {
let bucket = self.hash_index.get(&hash)?;
for key in bucket {
if let Some(entry) = self.strings.get(*key)
&& entry.data.as_ref() == bytes
{
return Some(StringPtr(*key));
}
}
None
}
#[hotpath::measure]
pub(super) fn insert_with_hash(&mut self, bytes: Box<[u8]>, hash: u64) -> StringPtr {
let entry = StringEntry {
data: bytes,
hash,
color: Cell::new(Color::Unmarked),
};
let key = self.strings.insert(entry);
self.hash_index.entry(hash).or_default().push(key);
StringPtr(key)
}
pub(super) fn mark(&self, ptr: StringPtr) {
if let Some(entry) = self.strings.get(ptr.0) {
entry.color.set(Color::Reachable);
}
}
#[hotpath::measure(label = "object::string_pool_collect")]
pub(super) fn collect(&mut self) {
let mut removed: Vec<(StringKey, u64)> = Vec::new();
self.strings.retain(|key, entry| match entry.color.get() {
Color::Reachable => {
entry.color.set(Color::Unmarked);
true
}
Color::Unmarked => {
removed.push((key, entry.hash));
false
}
});
for (key, hash) in removed {
if let Some(bucket) = self.hash_index.get_mut(&hash) {
bucket.retain(|k| *k != key);
if bucket.is_empty() {
self.hash_index.shift_remove(&hash);
}
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Two freshly allocated tables are both retrievable and counted.
    #[test]
    fn test_basic_allocation() {
        let mut heap = GcHeap::with_threshold(100);
        let first = heap.alloc_table();
        let second = heap.alloc_table();

        for ptr in [first, second] {
            assert!(heap.as_table_ref(ptr).is_some());
        }
        assert_eq!(heap.object_count(), 2);
    }

    /// Only the marked table survives a collection.
    #[test]
    fn test_gc_collect() {
        let mut heap = GcHeap::with_threshold(100);
        let survivor = heap.alloc_table();
        let _garbage = heap.alloc_table();
        let upvalues = UpvaluePool::new();

        heap.mark(survivor, &upvalues);
        heap.collect();

        assert!(heap.as_table_ref(survivor).is_some());
        assert_eq!(heap.object_count(), 1);
    }

    /// Dereferencing a swept object handle panics instead of returning junk.
    #[test]
    #[should_panic(expected = "use-after-free")]
    fn test_use_after_free_detection() {
        let mut heap = GcHeap::with_threshold(100);
        let dangling = heap.alloc_table();

        // Nothing was marked, so the table is swept.
        heap.collect();

        let _ = heap.as_table_ref(dangling);
    }

    /// Equal strings are interned once and share a handle.
    #[test]
    fn test_string_allocation() {
        let mut heap = GcHeap::with_threshold(100);
        let hello = heap.alloc_string(b"hello");
        let world = heap.alloc_string(b"world");
        let hello_again = heap.alloc_string(b"hello");

        assert_eq!(heap.get_string(hello), b"hello");
        assert_eq!(heap.get_string(world), b"world");
        assert_eq!(hello, hello_again);
        assert_eq!(heap.string_count(), 2);
    }

    /// Only the marked string survives a collection.
    #[test]
    fn test_string_gc_collect() {
        let mut heap = GcHeap::with_threshold(100);
        let survivor = heap.alloc_string(b"keep");
        let _garbage = heap.alloc_string(b"free");

        heap.mark_string(survivor);
        heap.collect();

        assert_eq!(heap.get_string(survivor), b"keep");
        assert_eq!(heap.string_count(), 1);
    }

    /// Dereferencing a swept string handle panics instead of returning junk.
    #[test]
    #[should_panic(expected = "use-after-free")]
    fn test_string_use_after_free_detection() {
        let mut heap = GcHeap::with_threshold(100);
        let dangling = heap.alloc_string(b"test");

        // Nothing was marked, so the string is swept.
        heap.collect();

        let _ = heap.get_string(dangling);
    }
}