pub mod trace;
use std::{
cmp::{PartialEq, Eq},
rc::Rc,
hash::{Hash, Hasher},
};
use hashbrown::{HashMap, HashSet};
use crate::trace::*;
/// Re-exports of everything needed to use a [`Heap`]: the heap itself, plain and
/// rooted handles, and the tracing trait/tracer. Glob-import this module.
pub mod prelude {
pub use super::{
Heap,
Handle,
Rooted,
trace::{Trace, Tracer},
};
}
/// Number assigned to each inserted object by `Heap::new_generation` and stored in
/// its `Handle`; handle equality compares it together with the pointer, so two
/// handles with different generations are unequal even at the same address.
type Generation = usize;
/// A garbage-collected heap of boxed `T` objects addressed by [`Handle`]s.
///
/// Objects that are rooted (see `Heap::insert` / `Heap::make_rooted`), or traced
/// from a rooted object, survive `Heap::clean`; everything else is freed.
pub struct Heap<T> {
/// Sweep number used by the most recent clean; the next clean uses `last_sweep + 1`.
last_sweep: usize,
/// Per-object sweep number; an object survives a clean only if its entry equals
/// that clean's sweep number (presumably written by `Tracer::mark` — see trace module).
object_sweeps: HashMap<Handle<T>, usize>,
/// Generation counter, incremented once per inserted object.
obj_counter: Generation,
/// Every object currently owned by the heap, rooted or not.
objects: HashSet<Handle<T>>,
/// Root entries. The heap keeps one clone of each `Rc`; an entry stays a root
/// while its strong count is above one, i.e. while some `Rooted` still exists.
rooted: HashMap<Handle<T>, Rc<()>>,
}
impl<T> Default for Heap<T> {
fn default() -> Self {
Self {
last_sweep: 0,
object_sweeps: HashMap::default(),
obj_counter: 0,
objects: HashSet::default(),
rooted: HashMap::default(),
}
}
}
impl<T> Heap<T> {
    /// Creates an empty heap (same as [`Heap::default`]).
    pub fn new() -> Self {
        Self::default()
    }

    /// Bumps and returns the generation counter, giving each inserted object a
    /// distinct generation for its handle.
    fn new_generation(&mut self) -> Generation {
        self.obj_counter += 1;
        self.obj_counter
    }

    /// Moves `object` onto the heap without rooting it.
    ///
    /// Unless the object is reachable from a root, or is passed to
    /// [`Heap::clean_excluding`], the next clean frees it.
    pub fn insert_temp(&mut self, object: T) -> Handle<T> {
        let ptr = Box::into_raw(Box::new(object));
        let gen = self.new_generation();
        let handle = Handle { gen, ptr };
        self.objects.insert(handle);
        handle
    }

    /// Moves `object` onto the heap and returns a root for it.
    ///
    /// The object (and everything traced from it) survives cleans for as long as
    /// the returned [`Rooted`], or a clone of it, is alive.
    pub fn insert(&mut self, object: T) -> Rooted<T> {
        let handle = self.insert_temp(object);
        let rc = Rc::new(());
        self.rooted.insert(handle, Rc::clone(&rc));
        Rooted { rc, handle }
    }

    /// Returns a root for an object already on the heap.
    ///
    /// The new [`Rooted`] shares the root count of any existing roots for the
    /// same handle. The handle should refer to a live object; this is only
    /// checked in debug builds. In release builds, rooting a handle whose object
    /// has already been freed does not resurrect anything. The stale root entry
    /// is discarded by the next clean and is never dereferenced.
    pub fn make_rooted(&mut self, handle: impl AsRef<Handle<T>>) -> Rooted<T> {
        let handle = *handle.as_ref();
        debug_assert!(self.contains(handle));
        let rc = Rc::clone(self.rooted.entry(handle).or_insert_with(|| Rc::new(())));
        Rooted { rc, handle }
    }

    /// Number of objects currently owned by the heap, including unreachable
    /// objects that have not been cleaned yet.
    pub fn len(&self) -> usize {
        self.objects.len()
    }

    /// Returns `true` if the heap owns no objects.
    pub fn is_empty(&self) -> bool {
        self.objects.is_empty()
    }

    /// Returns `true` if `handle` refers to an object this heap still owns.
    pub fn contains(&self, handle: impl AsRef<Handle<T>>) -> bool {
        self.objects.contains(handle.as_ref())
    }

    /// Returns a shared reference to the object, or `None` if it has been freed
    /// (or never belonged to this heap).
    pub fn get(&self, handle: impl AsRef<Handle<T>>) -> Option<&T> {
        let handle = handle.as_ref();
        if self.contains(handle) {
            // SAFETY: the handle is in `objects`, so its box is still allocated;
            // the borrow is tied to `&self`, so the safe API cannot hand out a
            // `&mut` to it at the same time.
            Some(unsafe { &*handle.ptr })
        } else {
            None
        }
    }

    /// Returns a shared reference to the object without checking liveness
    /// (except via `debug_assert!` in debug builds).
    ///
    /// # Safety
    ///
    /// `handle` must refer to an object that this heap still owns.
    pub unsafe fn get_unchecked(&self, handle: impl AsRef<Handle<T>>) -> &T {
        let handle = handle.as_ref();
        debug_assert!(self.contains(handle));
        &*handle.ptr
    }

    /// Returns a mutable reference to the object, or `None` if it has been freed
    /// (or never belonged to this heap).
    pub fn get_mut(&mut self, handle: impl AsRef<Handle<T>>) -> Option<&mut T> {
        let handle = handle.as_ref();
        if self.contains(handle) {
            // SAFETY: the handle is in `objects`, so its box is still allocated;
            // `&mut self` makes this the only reference handed out by the safe API.
            Some(unsafe { &mut *handle.ptr })
        } else {
            None
        }
    }

    /// Returns a mutable reference to the object without checking liveness
    /// (except via `debug_assert!` in debug builds).
    ///
    /// # Safety
    ///
    /// `handle` must refer to an object that this heap still owns.
    pub unsafe fn get_mut_unchecked(&mut self, handle: impl AsRef<Handle<T>>) -> &mut T {
        let handle = handle.as_ref();
        debug_assert!(self.contains(handle));
        &mut *handle.ptr
    }
}

impl<T: Trace<T>> Heap<T> {
    /// Frees every object that is not reachable from a live root or from one of
    /// the `excluding` handles.
    ///
    /// - Root entries with no remaining [`Rooted`] are released. So are root
    ///   entries whose object is no longer owned by the heap.
    /// - Handles in `excluding` that do not refer to a live object are ignored.
    pub fn clean_excluding(&mut self, excluding: impl IntoIterator<Item = Handle<T>>) {
        let new_sweep = self.last_sweep + 1;
        // Mark phase, scoped so the tracer's borrows of the heap's fields end
        // before the sweep needs `&mut self`.
        {
            let objects = &self.objects;
            let rooted = &mut self.rooted;
            let mut tracer = Tracer {
                new_sweep,
                object_sweeps: &mut self.object_sweeps,
                objects,
            };
            rooted.retain(|handle, rc| {
                // A strong count of 1 means only the heap's own clone remains.
                // A handle missing from `objects` was rooted after its object was
                // freed and must never be dereferenced.
                let keep = Rc::strong_count(rc) > 1 && objects.contains(handle);
                if keep {
                    tracer.mark(*handle);
                    // SAFETY: the handle is in `objects`, so its box is still allocated.
                    unsafe { (&*handle.ptr).trace(&mut tracer); }
                }
                keep
            });
            for handle in excluding {
                if objects.contains(&handle) {
                    tracer.mark(handle);
                    // SAFETY: the handle is in `objects`, so its box is still allocated.
                    unsafe { (&*handle.ptr).trace(&mut tracer); }
                }
            }
        }
        self.sweep_unmarked(new_sweep);
        self.last_sweep = new_sweep;
    }

    /// Frees every object whose `object_sweeps` entry is not `sweep`.
    ///
    /// Dead handles are removed from `objects` before any of them is freed. A
    /// panicking `Drop` for `T` can therefore at worst leak the remaining dead
    /// objects. It never leaves a freed pointer in `objects` for `Heap::drop`
    /// to free a second time.
    fn sweep_unmarked(&mut self, sweep: usize) {
        let object_sweeps = &mut self.object_sweeps;
        let mut dead = Vec::new();
        self.objects.retain(|handle| {
            let marked = object_sweeps.get(handle) == Some(&sweep);
            if !marked {
                dead.push(*handle);
            }
            marked
        });
        for handle in dead {
            object_sweeps.remove(&handle);
            // SAFETY: the handle came from `objects` (so it owns a live box created
            // by `Box::into_raw`) and has just been removed from it, so nothing
            // else will free it again.
            drop(unsafe { Box::from_raw(handle.ptr) });
        }
    }

    /// Frees every object not reachable from a live root.
    pub fn clean(&mut self) {
        self.clean_excluding(std::iter::empty());
    }
}
impl<T> Drop for Heap<T> {
    /// Frees every object the heap still owns, whether rooted or not.
    fn drop(&mut self) {
        self.objects.iter().for_each(|handle| {
            // SAFETY: each handle in `objects` owns a box created by
            // `Box::into_raw` that has not been freed yet.
            let owned: Box<T> = unsafe { Box::from_raw(handle.ptr) };
            drop(owned);
        });
    }
}
/// A copyable, non-owning reference to an object stored in a [`Heap`].
///
/// A handle keeps nothing alive by itself. Its identity is the pair
/// (generation, address).
pub struct Handle<T> {
    gen: Generation,
    ptr: *mut T,
}

// Written by hand rather than derived: `#[derive(Debug)]` would require
// `T: Debug`, but both fields (`Generation`, `*mut T`) are Debug for every `T`.
// The output format is identical to the derived one.
impl<T> std::fmt::Debug for Handle<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Handle")
            .field("gen", &self.gen)
            .field("ptr", &self.ptr)
            .finish()
    }
}
impl<T> Handle<T> {
    /// Returns a shared reference to the object without going through a heap.
    ///
    /// # Safety
    ///
    /// The object must still be owned by its [`Heap`]: it must not have been freed
    /// by a clean or by dropping the heap. No mutable reference to it may exist
    /// while the returned reference is in use.
    pub unsafe fn get_unchecked(&self) -> &T {
        &*self.ptr
    }

    /// Returns a mutable reference to the object without going through a heap.
    ///
    /// # Safety
    ///
    /// Same liveness requirement as [`Handle::get_unchecked`]. The caller must
    /// also guarantee that no other reference to the object, shared or mutable,
    /// exists while the returned one is in use. Because this method takes
    /// `&self`, the borrow checker cannot enforce that.
    // Exclusivity is part of the documented safety contract above.
    #[allow(clippy::mut_from_ref)]
    pub unsafe fn get_mut_unchecked(&self) -> &mut T {
        &mut *self.ptr
    }
}
// Copy/Clone are implemented by hand so they hold for every `T`; derives would
// add `T: Copy` / `T: Clone` bounds.
impl<T> Copy for Handle<T> {}

impl<T> Clone for Handle<T> {
    /// `Handle` is `Copy`, so cloning is a plain bitwise copy.
    fn clone(&self) -> Self {
        *self
    }
}
impl<T> PartialEq<Self> for Handle<T> {
    /// Handles are equal only when both generation and address match, so two
    /// handles with different generations differ even at the same address.
    fn eq(&self, other: &Self) -> bool {
        (self.gen, self.ptr) == (other.gen, other.ptr)
    }
}

impl<T> Eq for Handle<T> {}

impl<T> Hash for Handle<T> {
    /// Feeds the same `(gen, ptr)` pair that equality compares, in that order,
    /// into the hasher, which keeps `Hash` consistent with `Eq`.
    fn hash<H: Hasher>(&self, state: &mut H) {
        (self.gen, self.ptr).hash(state);
    }
}
impl<T> AsRef<Handle<T>> for Handle<T> {
fn as_ref(&self) -> &Handle<T> {
self
}
}
impl<T> From<Rooted<T>> for Handle<T> {
fn from(rooted: Rooted<T>) -> Self {
rooted.handle
}
}
#[derive(Debug)]
pub struct Rooted<T> {
rc: Rc<()>,
handle: Handle<T>,
}
impl<T> Clone for Rooted<T> {
fn clone(&self) -> Self {
Self {
rc: self.rc.clone(),
handle: self.handle,
}
}
}
impl<T> AsRef<Handle<T>> for Rooted<T> {
fn as_ref(&self) -> &Handle<T> {
&self.handle
}
}
impl<T> Rooted<T> {
pub fn into_handle(self) -> Handle<T> {
self.handle
}
pub fn handle(&self) -> Handle<T> {
self.handle
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Test object. Each value is built with a `new_count()` call that increments
    /// a shared counter, and `Drop` decrements it. A final count of zero means
    /// every created value has been dropped.
    enum Value<'a> {
        Base(&'a AtomicUsize),
        Refs(&'a AtomicUsize, Handle<Value<'a>>, Handle<Value<'a>>),
    }

    impl<'a> Trace<Self> for Value<'a> {
        fn trace(&self, tracer: &mut Tracer<Self>) {
            match self {
                Value::Base(_) => {},
                Value::Refs(_, a, b) => {
                    a.trace(tracer);
                    b.trace(tracer);
                },
            }
        }
    }

    impl<'a> Drop for Value<'a> {
        fn drop(&mut self) {
            match self {
                Value::Base(count) | Value::Refs(count, _, _) =>
                    count.fetch_sub(1, Ordering::Relaxed),
            };
        }
    }

    /// A rooted object survives a clean; once unrooted it is freed.
    #[test]
    fn basic() {
        let count: AtomicUsize = AtomicUsize::new(0);
        let new_count = || {
            count.fetch_add(1, Ordering::Relaxed);
            &count
        };
        let mut heap = Heap::default();
        let a = heap.insert(Value::Base(new_count()));
        heap.clean();
        assert!(heap.contains(&a));
        let a = a.into_handle();
        heap.clean();
        assert!(!heap.contains(&a));
        drop(heap);
        assert_eq!(count.load(Ordering::Acquire), 0);
    }

    /// Only objects reachable from a live root survive.
    #[test]
    fn ownership() {
        let count: AtomicUsize = AtomicUsize::new(0);
        let new_count = || {
            count.fetch_add(1, Ordering::Relaxed);
            &count
        };
        let mut heap = Heap::default();
        let a = heap.insert(Value::Base(new_count())).handle();
        let b = heap.insert(Value::Base(new_count())).handle();
        let c = heap.insert(Value::Base(new_count())).handle();
        let d = heap.insert(Value::Refs(new_count(), a, c));
        let e = heap.insert(Value::Base(new_count())).handle();
        heap.clean();
        assert!(heap.contains(&a));
        assert!(!heap.contains(&b));
        assert!(heap.contains(&c));
        assert!(heap.contains(&d));
        assert!(!heap.contains(&e));
        let a = heap.insert_temp(Value::Base(new_count()));
        heap.clean();
        assert!(!heap.contains(&a));
        let a = heap.insert_temp(Value::Base(new_count()));
        let a = heap.make_rooted(a);
        heap.clean();
        assert!(heap.contains(&a));
        drop(heap);
        assert_eq!(count.load(Ordering::Acquire), 0);
    }

    /// A self-referencing object is freed once unrooted, while objects it
    /// referenced that are still rooted survive.
    #[test]
    fn recursive() {
        let count: AtomicUsize = AtomicUsize::new(0);
        let new_count = || {
            count.fetch_add(1, Ordering::Relaxed);
            &count
        };
        let mut heap = Heap::default();
        let a = heap.insert(Value::Base(new_count()));
        let b = heap.insert(Value::Base(new_count()));
        *heap.get_mut(&a).unwrap() = Value::Refs(new_count(), a.handle(), b.handle());
        heap.clean();
        assert!(heap.contains(&a));
        assert!(heap.contains(&b));
        let a = a.into_handle();
        heap.clean();
        assert!(!heap.contains(&a));
        assert!(heap.contains(&b));
        drop(heap);
        assert_eq!(count.load(Ordering::Acquire), 0);
    }

    /// Temporaries survive only when referenced by a root or explicitly excluded.
    #[test]
    fn temporary() {
        let count: AtomicUsize = AtomicUsize::new(0);
        let new_count = || {
            count.fetch_add(1, Ordering::Relaxed);
            &count
        };
        let mut heap = Heap::default();
        let a = heap.insert_temp(Value::Base(new_count()));
        heap.clean();
        assert!(!heap.contains(&a));
        let a = heap.insert_temp(Value::Base(new_count()));
        let b = heap.insert(Value::Refs(new_count(), a, a));
        heap.clean();
        assert!(heap.contains(&a));
        assert!(heap.contains(&b));
        let a = heap.insert_temp(Value::Base(new_count()));
        heap.clean_excluding(Some(a));
        assert!(heap.contains(&a));
        drop(heap);
        assert_eq!(count.load(Ordering::Acquire), 0);
    }
}