use crate::heap::{Handle, Heap};
use alloc::collections::BTreeSet;
use alloc::vec::Vec;
/// Lets the collector discover the outgoing references of a heap object.
///
/// Implementations report every [`Handle`] the object holds. The mark phase
/// follows exactly the handles reported here, so a reference that is not
/// reported does not keep its target alive.
pub trait Trace {
    /// Passes each handle stored in `self` to `visit`.
    fn trace(&self, visit: &mut dyn FnMut(Handle));
}
/// Lets the collector rewrite an object's references after objects have moved.
///
/// Compaction and reloading call this to translate every stored handle from
/// its old value to its new one.
pub trait Relocate {
    /// Replaces each handle `h` stored in `self` with `forward(h)`.
    fn relocate(&mut self, forward: &dyn Fn(Handle) -> Handle);
}
/// Counts reported by a single collection pass.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct Stats {
    /// Number of handles the mark phase found reachable.
    pub marked: usize,
    /// Number of objects freed (or discarded by compaction).
    pub swept: usize,
}
/// Age at which an object counts as belonging to the old generation.
///
/// `collect_minor` only considers objects whose age is below this value.
pub const OLD_AGE: u8 = 1;
/// Returns the set of live handles transitively reachable from `roots`.
///
/// Handles that `heap` does not report as live are skipped, whether they are
/// roots or edges. Every handle in the result therefore passed
/// `Heap::is_live` when it was discovered. The traversal uses an explicit
/// stack rather than recursion.
fn mark<T: Trace>(heap: &Heap<T>, roots: impl IntoIterator<Item = Handle>) -> BTreeSet<Handle> {
    let mut reached: BTreeSet<Handle> = BTreeSet::new();
    // Shaded handles whose outgoing edges have not been traced yet.
    let mut to_scan: Vec<Handle> = Vec::new();
    // Newly discovered candidates: first the roots, then one object's edges
    // at a time. The buffer is reused so each object costs no new allocation.
    let mut found: Vec<Handle> = roots.into_iter().collect();
    loop {
        for candidate in found.drain(..) {
            if heap.is_live(candidate) && reached.insert(candidate) {
                to_scan.push(candidate);
            }
        }
        let Some(current) = to_scan.pop() else {
            break;
        };
        if let Some(object) = heap.get(current) {
            object.trace(&mut |child| found.push(child));
        }
    }
    reached
}
/// Performs a full (major) collection.
///
/// Every live object reachable from `roots` is promoted with `Heap::tenure`.
/// Every other live object is freed. The remembered set is cleared afterwards.
///
/// Returns the number of reachable handles (`marked`) and the number of
/// objects freed (`swept`).
pub fn collect<T: Trace>(heap: &mut Heap<T>, roots: &[Handle]) -> Stats {
    let reachable = mark(heap, roots.iter().copied());
    let mut freed = 0;
    for candidate in heap.live_handles() {
        if reachable.contains(&candidate) {
            heap.tenure(candidate);
            continue;
        }
        heap.free(candidate);
        freed += 1;
    }
    heap.clear_remembered();
    Stats {
        marked: reachable.len(),
        swept: freed,
    }
}
/// Performs a minor collection of the young generation.
///
/// Marking starts from `roots` plus whatever `Heap::remembered_roots`
/// reports. Those are presumably the handles recorded by the write barrier
/// (`Heap::record_edge`). Only objects whose age is below [`OLD_AGE`] are
/// candidates: reachable ones are tenured and the rest are freed. Old objects
/// are never freed here. The remembered set is cleared afterwards.
///
/// `marked` counts every handle reached, including old objects traversed on
/// the way. `swept` counts only the young objects freed.
pub fn collect_minor<T: Trace>(heap: &mut Heap<T>, roots: &[Handle]) -> Stats {
    let barrier_roots = heap.remembered_roots();
    let reachable = mark(heap, roots.iter().copied().chain(barrier_roots));
    let mut freed = 0;
    for candidate in heap.handles_where(|age| age < OLD_AGE) {
        if reachable.contains(&candidate) {
            heap.tenure(candidate);
            continue;
        }
        heap.free(candidate);
        freed += 1;
    }
    heap.clear_remembered();
    Stats {
        marked: reachable.len(),
        swept: freed,
    }
}
/// Compacting full collection for callers with no extra handles to fix up.
///
/// Equivalent to [`compact_with`] with a fixup callback that does nothing.
/// `roots` are rewritten in place to the survivors' new handles.
pub fn compact<T: Trace + Relocate>(heap: &mut Heap<T>, roots: &mut [Handle]) -> Stats {
    let mut ignore_forwarding = |_forward: &dyn Fn(Handle) -> Handle| {};
    compact_with(heap, roots, &mut ignore_forwarding)
}
/// Runs a full collection that also compacts the heap.
///
/// Objects reachable from `roots` are kept and moved by `Heap::compact_to`.
/// All other objects are discarded. Every handle is then translated from its
/// old value to its new one through a forwarding function. Handles absent
/// from the forwarding map are returned unchanged. The translation covers:
///
/// * handles stored inside surviving objects, via [`relocate`];
/// * `roots`, rewritten in place;
/// * any handles the caller keeps elsewhere, via `fixup`, which is called
///   once with the forwarding function.
///
/// Like [`collect`], this is a major collection, so survivors are tenured and
/// the remembered set is cleared. Clearing also matters for correctness: the
/// remembered set stores handles, and after compaction those pre-compaction
/// handles may name different objects or nothing at all. Tenuring every
/// survivor means no young object is left whose old-to-young references would
/// need to be re-recorded.
///
/// Returns the number of surviving objects (`marked`) and the number
/// discarded (`swept`).
// The nested `dyn` callback type of `fixup` trips clippy's `type_complexity`
// lint. The signature is public API, so the lint is silenced rather than
// changing it.
#[allow(clippy::type_complexity)]
pub fn compact_with<T: Trace + Relocate>(
    heap: &mut Heap<T>,
    roots: &mut [Handle],
    fixup: &mut dyn FnMut(&dyn Fn(Handle) -> Handle),
) -> Stats {
    let before = heap.len();
    let marked = mark(heap, roots.iter().copied());
    let map: alloc::collections::BTreeMap<Handle, Handle> =
        heap.compact_to(&marked).into_iter().collect();
    let forward = |h: Handle| map.get(&h).copied().unwrap_or(h);
    relocate(heap, &forward);
    for r in roots.iter_mut() {
        *r = forward(*r);
    }
    fixup(&forward);
    // After compaction, every live handle is a survivor at its final location.
    for handle in heap.live_handles() {
        heap.tenure(handle);
    }
    // Remembered entries hold pre-compaction handles and are now meaningless.
    heap.clear_remembered();
    Stats {
        marked: marked.len(),
        swept: before - marked.len(),
    }
}
/// Rewrites the handles stored in every live object through `forward`.
///
/// Calls [`Relocate::relocate`] on each object listed by `heap.live_handles()`.
/// Handles held outside the heap (roots, caller data) are not touched.
pub fn relocate<T: Relocate>(heap: &mut Heap<T>, forward: &dyn Fn(Handle) -> Handle) {
    let handles = heap.live_handles();
    for handle in handles {
        let Some(object) = heap.get_mut(handle) else {
            continue;
        };
        object.relocate(forward);
    }
}
/// State of an interruptible mark phase that can be advanced in small steps.
///
/// A handle in `marked` that is still in `grey` is waiting to be scanned.
/// A handle in `marked` that has left `grey` has been fully scanned. Live
/// handles outside `marked` are freed by `sweep`.
pub struct IncrementalMarker {
    /// Every handle shaded so far, whether or not it has been scanned.
    marked: BTreeSet<Handle>,
    /// Worklist of shaded handles whose outgoing edges are not yet traced.
    grey: Vec<Handle>,
}
impl IncrementalMarker {
    /// Starts an incremental mark with `roots` as the initial grey set.
    ///
    /// Duplicate roots are shaded only once. Liveness is not checked here
    /// because no heap is available.
    #[must_use]
    pub fn new(roots: &[Handle]) -> Self {
        let mut m = Self {
            marked: BTreeSet::new(),
            grey: Vec::new(),
        };
        for &r in roots {
            m.mark_grey(r);
        }
        m
    }

    /// Shades `handle` grey if it has not been marked yet.
    ///
    /// This also serves as the write barrier. If the mutator stores a
    /// reference before `sweep` runs (even after marking looked complete),
    /// it must pass the stored handle here. It must then keep stepping until
    /// [`is_complete`](Self::is_complete) returns `true`, so that the target
    /// survives the sweep.
    pub fn mark_grey(&mut self, handle: Handle) {
        if self.marked.insert(handle) {
            self.grey.push(handle);
        }
    }

    /// Scans up to `budget` grey objects, shading each live handle they reference.
    ///
    /// Returns `true` when no grey objects remain (marking is complete) and
    /// `false` when more work is pending. A `budget` of zero does no work.
    pub fn step<T: Trace>(&mut self, heap: &Heap<T>, budget: usize) -> bool {
        for _ in 0..budget {
            let Some(handle) = self.grey.pop() else {
                return true;
            };
            let mut edges: Vec<Handle> = Vec::new();
            if let Some(obj) = heap.get(handle) {
                obj.trace(&mut |h| edges.push(h));
            }
            for edge in edges {
                // References to objects the heap no longer holds are ignored.
                if heap.is_live(edge) {
                    self.mark_grey(edge);
                }
            }
        }
        self.grey.is_empty()
    }

    /// Returns `true` when the grey worklist is empty.
    #[must_use]
    pub fn is_complete(&self) -> bool {
        self.grey.is_empty()
    }

    /// Frees every live object that was not marked and returns how many were freed.
    ///
    /// # Panics
    ///
    /// Panics if marking is not complete. Objects still in the grey worklist
    /// have not had their children traced yet. Sweeping at that point would
    /// free objects that are still reachable.
    pub fn sweep<T>(&self, heap: &mut Heap<T>) -> usize {
        assert!(
            self.is_complete(),
            "IncrementalMarker::sweep called before marking completed"
        );
        let mut swept = 0;
        for handle in heap.live_handles() {
            if !self.marked.contains(&handle) {
                heap.free(handle);
                swept += 1;
            }
        }
        swept
    }
}
// Unit tests for the collector. `Node` is a minimal heap object: `tag`
// identifies it across moves, and `edges` holds its outgoing references.
#[cfg(test)]
mod tests {
use super::*;
/// Minimal traceable and relocatable heap object used by every test.
struct Node {
tag: u32,
edges: Vec<Handle>,
}
impl Node {
/// Creates a node with the given tag and no outgoing references.
fn new(tag: u32) -> Self {
Self {
tag,
edges: Vec::new(),
}
}
}
// A node's references are exactly the handles in `edges`.
impl Trace for Node {
fn trace(&self, visit: &mut dyn FnMut(Handle)) {
for &e in &self.edges {
visit(e);
}
}
}
// Relocation rewrites each stored edge through the forwarding function.
impl Relocate for Node {
fn relocate(&mut self, forward: &dyn Fn(Handle) -> Handle) {
for e in &mut self.edges {
*e = forward(*e);
}
}
}
/// An object unreachable from the roots is freed; the rooted one survives intact.
#[test]
fn unreachable_objects_are_swept() {
let mut heap: Heap<Node> = Heap::new();
let keep = heap.alloc(Node::new(1));
let drop = heap.alloc(Node::new(2));
assert_eq!(heap.len(), 2);
let stats = collect(&mut heap, &[keep]);
assert_eq!(stats.marked, 1);
assert_eq!(stats.swept, 1);
assert!(heap.is_live(keep));
assert!(!heap.is_live(drop));
assert_eq!(heap.len(), 1);
assert_eq!(heap.get(keep).unwrap().tag, 1);
}
/// Everything on the chain a -> b -> c survives; the unreferenced d is freed.
#[test]
fn reachable_chain_is_kept() {
let mut heap: Heap<Node> = Heap::new();
let c = heap.alloc(Node::new(3));
let mut b_node = Node::new(2);
b_node.edges.push(c);
let b = heap.alloc(b_node);
let mut a_node = Node::new(1);
a_node.edges.push(b);
let a = heap.alloc(a_node);
// Not referenced by any other object and not a root.
let _d = heap.alloc(Node::new(4));
let stats = collect(&mut heap, &[a]);
assert_eq!(stats.marked, 3);
assert_eq!(stats.swept, 1);
assert!(heap.is_live(a) && heap.is_live(b) && heap.is_live(c));
assert_eq!(heap.len(), 3);
}
/// Mutually referencing objects with no path from the roots are still collected.
#[test]
fn cycles_among_garbage_are_collected() {
let mut heap: Heap<Node> = Heap::new();
let x = heap.alloc(Node::new(1));
let y = heap.alloc(Node::new(2));
heap.get_mut(x).unwrap().edges.push(y);
heap.get_mut(y).unwrap().edges.push(x);
let survivor = heap.alloc(Node::new(3));
let stats = collect(&mut heap, &[survivor]);
assert_eq!(stats.swept, 2);
assert_eq!(stats.marked, 1);
assert!(!heap.is_live(x) && !heap.is_live(y));
assert!(heap.is_live(survivor));
}
/// A full collection moves survivors from age 0 to `OLD_AGE`.
#[test]
fn major_collection_promotes_survivors() {
let mut heap: Heap<Node> = Heap::new();
let keep = heap.alloc(Node::new(1));
assert_eq!(heap.age(keep), Some(0)); collect(&mut heap, &[keep]);
assert_eq!(heap.age(keep), Some(OLD_AGE)); }
/// A minor collection frees only unreachable young objects and tenures young survivors.
#[test]
fn minor_collection_sweeps_only_the_young() {
let mut heap: Heap<Node> = Heap::new();
let old = heap.alloc(Node::new(1));
collect(&mut heap, &[old]);
assert_eq!(heap.age(old), Some(OLD_AGE));
let young_keep = heap.alloc(Node::new(2));
let young_garbage = heap.alloc(Node::new(3));
let stats = collect_minor(&mut heap, &[old, young_keep]);
assert_eq!(stats.swept, 1);
assert!(heap.is_live(old) && heap.is_live(young_keep));
assert!(!heap.is_live(young_garbage));
assert_eq!(heap.age(young_keep), Some(OLD_AGE)); }
/// A young object referenced by an old one (with the barrier recorded) survives a minor collection.
#[test]
fn minor_collection_keeps_young_referenced_by_old() {
let mut heap: Heap<Node> = Heap::new();
let old = heap.alloc(Node::new(1));
collect(&mut heap, &[old]);
let young = heap.alloc(Node::new(2));
// Store the old -> young reference and record it through the write barrier.
heap.get_mut(old).unwrap().edges.push(young); heap.record_edge(old, young, OLD_AGE);
let stats = collect_minor(&mut heap, &[old]);
assert_eq!(stats.swept, 0);
assert!(heap.is_live(young));
}
/// A young object that nothing references and no barrier recorded is freed by a minor collection.
#[test]
fn minor_collection_frees_young_when_no_barrier_recorded() {
let mut heap: Heap<Node> = Heap::new();
let old = heap.alloc(Node::new(1));
collect(&mut heap, &[old]);
let young = heap.alloc(Node::new(2)); let stats = collect_minor(&mut heap, &[old]);
assert_eq!(stats.swept, 1);
assert!(!heap.is_live(young));
}
/// Compaction discards garbage, rewrites roots and object edges, and keeps the objects intact.
#[test]
fn compaction_relocates_survivors_and_fixes_references() {
let mut heap: Heap<Node> = Heap::new();
let b = heap.alloc(Node::new(2));
// `_c` and `_d` are garbage; only `a` (a root) and `b` (via `a`) are reachable.
let _c = heap.alloc(Node::new(3)); let mut a_node = Node::new(1);
a_node.edges.push(b);
let a = heap.alloc(a_node);
let _d = heap.alloc(Node::new(4));
let mut roots = [a];
let stats = compact(&mut heap, &mut roots);
assert_eq!(stats.marked, 2);
assert_eq!(stats.swept, 2);
// Follow the rewritten root and edge rather than the stale `a`/`b` handles.
let a2 = roots[0];
assert_eq!(heap.get(a2).unwrap().tag, 1);
let b2 = heap.get(a2).unwrap().edges[0];
assert_eq!(heap.get(b2).unwrap().tag, 2);
assert_eq!(heap.len(), 2);
assert_eq!(heap.live_handles().len(), 2);
}
/// A reachable cycle still closes on itself after compaction rewrites its handles.
#[test]
fn compaction_preserves_a_reachable_cycle() {
let mut heap: Heap<Node> = Heap::new();
let x = heap.alloc(Node::new(1));
let y = heap.alloc(Node::new(2));
heap.get_mut(x).unwrap().edges.push(y);
heap.get_mut(y).unwrap().edges.push(x);
let _garbage = heap.alloc(Node::new(9));
let mut roots = [x];
let stats = compact(&mut heap, &mut roots);
assert_eq!(stats.marked, 2);
assert_eq!(stats.swept, 1);
let x2 = roots[0];
let y2 = heap.get(x2).unwrap().edges[0];
let back = heap.get(y2).unwrap().edges[0];
assert_eq!(back, x2);
}
/// Incremental marking in single-object steps reaches the same survivors as a full collection.
#[test]
fn incremental_marking_matches_a_full_collection() {
let mut heap: Heap<Node> = Heap::new();
let c = heap.alloc(Node::new(3));
let mut b = Node::new(2);
b.edges.push(c);
let b = heap.alloc(b);
let mut a = Node::new(1);
a.edges.push(b);
let a = heap.alloc(a);
let _d = heap.alloc(Node::new(4));
let _e = heap.alloc(Node::new(5));
let mut marker = IncrementalMarker::new(&[a]);
// A budget of 1 forces marking to run over several steps; the bound guards against a hang.
let mut steps = 0;
while !marker.step(&heap, 1) {
steps += 1;
assert!(steps < 100, "marking should terminate");
}
let swept = marker.sweep(&mut heap);
assert_eq!(swept, 2); assert!(heap.is_live(a) && heap.is_live(b) && heap.is_live(c));
assert_eq!(heap.len(), 3);
}
/// A reference stored after marking finished survives if the barrier shades its target.
#[test]
fn incremental_write_barrier_keeps_a_late_stored_reference() {
let mut heap: Heap<Node> = Heap::new();
let root = heap.alloc(Node::new(1)); let late = heap.alloc(Node::new(2));
let mut marker = IncrementalMarker::new(&[root]);
while !marker.step(&heap, 1) {}
assert!(marker.is_complete());
// Mutator stores a new reference after marking has already completed.
heap.get_mut(root).unwrap().edges.push(late);
// Write barrier: shade the stored target, then finish marking before sweeping.
marker.mark_grey(late); while !marker.step(&heap, 4) {}
let swept = marker.sweep(&mut heap);
assert_eq!(swept, 0, "the barrier-shaded object must survive");
assert!(heap.is_live(late));
}
/// A cycle reachable from a root survives a full collection with nothing swept.
#[test]
fn reachable_cycle_survives() {
let mut heap: Heap<Node> = Heap::new();
let x = heap.alloc(Node::new(1));
let y = heap.alloc(Node::new(2));
heap.get_mut(x).unwrap().edges.push(y);
heap.get_mut(y).unwrap().edges.push(x);
let stats = collect(&mut heap, &[x]);
assert_eq!(stats.marked, 2);
assert_eq!(stats.swept, 0);
assert!(heap.is_live(x) && heap.is_live(y));
}
/// Copies a heap into a fresh heap (like a snapshot reload) and uses `relocate`
/// to translate the copied edges from old handles to new ones.
#[test]
fn snapshot_reload_relocates_pointers() {
use alloc::collections::BTreeMap;
let mut a: Heap<Node> = Heap::new();
let n0 = a.alloc(Node::new(10));
let n1 = a.alloc(Node::new(11));
let n2 = a.alloc(Node::new(12));
a.get_mut(n0).unwrap().edges.push(n1);
a.get_mut(n1).unwrap().edges.push(n0);
a.get_mut(n1).unwrap().edges.push(n2);
// Snapshot each live object as (old handle, tag, outgoing edges).
let live = a.live_handles();
let snap: Vec<(Handle, u32, Vec<Handle>)> = live
.iter()
.map(|h| {
let mut edges = Vec::new();
a.get(*h).unwrap().trace(&mut |e| edges.push(e));
(*h, a.get(*h).unwrap().tag, edges)
})
.collect();
let mut b: Heap<Node> = Heap::new();
// Presumably occupies the first slot so reloaded handles differ from the originals (checked at the end).
let _decoy = b.alloc(Node::new(99));
let mut map: BTreeMap<Handle, Handle> = BTreeMap::new();
let mut new_handles = Vec::new();
for (old, tag, edges) in &snap {
let mut node = Node::new(*tag);
// Copied edges still hold old-heap handles until `relocate` runs below.
node.edges = edges.clone(); let nh = b.alloc(node);
map.insert(*old, nh);
new_handles.push(nh);
}
let forward = |h: Handle| map.get(&h).copied().unwrap_or(h);
relocate(&mut b, &forward);
let (b0, b1, b2) = (new_handles[0], new_handles[1], new_handles[2]);
assert_eq!(b.get(b0).unwrap().tag, 10);
assert_eq!(b.get(b0).unwrap().edges, alloc::vec![b1]);
assert_eq!(b.get(b1).unwrap().edges, alloc::vec![b0, b2]);
assert_eq!(b.get(b2).unwrap().tag, 12);
assert_ne!(b0, n0);
}
/// With no roots, a full collection frees every object.
#[test]
fn empty_roots_sweeps_everything() {
let mut heap: Heap<Node> = Heap::new();
heap.alloc(Node::new(1));
heap.alloc(Node::new(2));
let stats = collect(&mut heap, &[]);
assert_eq!(stats.swept, 2);
assert_eq!(stats.marked, 0);
assert!(heap.is_empty());
}
}