use crate::loom::*;
use parking_lot::lock_api::RawMutex;
use std::{
cell::UnsafeCell,
fmt::{Formatter, Pointer, Result},
ops::Deref,
ptr::NonNull,
sync::{Arc, Weak},
};
// Private module holding the `Sealed` marker trait. `Interner` names `Sealed`
// as a supertrait, and because this module is not public, only the types
// listed here can implement `Interner`.
mod private {
// Marker trait with no items; it exists only to restrict who may implement `Interner`.
pub trait Sealed {}
// The unit type is sealed-in as well (it has an `Interner` impl in this file).
impl Sealed for () {}
// Hash-based interner over any `Eq + Hash` value type, including unsized ones.
impl<T: ?Sized + Eq + std::hash::Hash> Sealed for crate::hash::Hash<T> {}
// Ordering-based interner over any `Ord` value type, including unsized ones.
impl<T: ?Sized + Ord> Sealed for crate::tree::Ord<T> {}
}
/// Back-end that an [`Interned`] handle can be registered with.
///
/// The trait requires `private::Sealed`, so it can only be implemented for
/// the types that module lists.
pub trait Interner: private::Sealed + Sized {
/// Type of the values handed out as `Interned<Self>`; may be unsized.
type T: ?Sized;
/// Asks the interner to remove `value`.
///
/// `Interned::drop` calls this when only one reference is left and treats a
/// `true` first component as "done". The second component is an optional
/// handle that the caller simply drops.
/// NOTE(review): presumably that handle is the interner's own reference to
/// the removed value — confirm against the implementations.
fn remove(&self, value: &Interned<Self>) -> (bool, Option<Interned<Self>>);
}
/// Placeholder interner for the unit type: it stores nothing, so there is
/// never anything to remove.
impl Interner for () {
    type T = ();

    /// Always reports "not removed" and hands back no handle.
    fn remove(&self, _value: &Interned<Self>) -> (bool, Option<Interned<Self>>) {
        let removed = false;
        let handle = None;
        (removed, handle)
    }
}
/// Per-allocation bookkeeping stored alongside the interned value.
///
/// `refs` and `cleanup` live in `UnsafeCell`s; in this file they are only
/// reached through a `Guard`, which `State::lock` hands out after acquiring
/// `mutex`.
struct State<I> {
// Raw mutex serialising access to the two cells below.
mutex: parking_lot::RawMutex,
// Number of live handles to this allocation; starts at 1 (see `State::new`),
// incremented by `Clone` and decremented by `Drop` of `Interned`.
refs: UnsafeCell<u32>,
// Weak link to the interner to notify on drop; `None` until
// `Interned::make_hot` stores one.
cleanup: UnsafeCell<Option<Weak<I>>>,
}
impl<I: Interner> State<I> {
    /// Builds the bookkeeping for a freshly created value: an unlocked mutex,
    /// a reference count of one and no cleanup target.
    pub fn new() -> Self {
        let mutex = parking_lot::RawMutex::INIT;
        let refs = UnsafeCell::new(1);
        let cleanup = UnsafeCell::new(None);
        Self {
            mutex,
            refs,
            cleanup,
        }
    }

    /// Acquires the mutex and returns a guard through which the count and the
    /// cleanup slot can be accessed; the guard unlocks again when dropped.
    pub fn lock(&self) -> Guard<'_, I> {
        let raw = &self.mutex;
        raw.lock();
        Guard(self)
    }
}
/// Token for a locked `State`: `State::lock` creates it right after acquiring
/// the mutex, and its `Drop` impl releases the mutex again.
struct Guard<'a, I>(&'a State<I>);
impl<'a, I> Guard<'a, I> {
    /// Returns the current reference count.
    pub fn refs(&self) -> u32 {
        let counter = self.0.refs.get();
        // SAFETY: in this file a `Guard` is only created by `State::lock`,
        // i.e. while the mutex is held, so nobody else touches the cell.
        unsafe { counter.read() }
    }

    /// Gives mutable access to the reference count.
    pub fn refs_mut(&mut self) -> &mut u32 {
        let counter = self.0.refs.get();
        // SAFETY: see `refs`; `&mut self` additionally rules out a second
        // borrow obtained through this same guard.
        unsafe { &mut *counter }
    }

    /// Gives mutable access to the slot holding the weak interner link.
    pub fn cleanup(&mut self) -> &mut Option<Weak<I>> {
        let slot = self.0.cleanup.get();
        // SAFETY: see `refs_mut`.
        unsafe { &mut *slot }
    }
}
/// Unlocks the underlying mutex when the guard goes out of scope.
impl<'a, I> Drop for Guard<'a, I> {
    fn drop(&mut self) {
        let state = self.0;
        // SAFETY: the guard was created by `State::lock`, which acquired this
        // very mutex, so it is currently held on behalf of this guard.
        unsafe { state.mutex.unlock() };
    }
}
/// Heap block shared by all clones of one `Interned`: the bookkeeping header
/// followed by the (possibly unsized) value.
///
/// `#[repr(C)]` pins the field order, which `RefCounted::from_box` relies on
/// when it computes the allocation layout by hand.
#[repr(C)]
struct RefCounted<I: Interner> {
// Lock, reference count and cleanup link.
state: State<I>,
// The interned value itself; must stay the last field because it may be unsized.
value: I::T,
}
impl<I: Interner> RefCounted<I> {
    /// Moves a boxed, possibly unsized value into a freshly allocated
    /// `RefCounted` block and returns a pointer to that block.
    ///
    /// The value's bytes are copied into the new block and the box's memory is
    /// released without running the value's destructor, so ownership of the
    /// value transfers to the returned allocation.
    ///
    /// Allocation failure is reported through `std::alloc::handle_alloc_error`
    /// instead of writing through a null pointer.
    ///
    /// # Panics
    ///
    /// Panics if the combined layout of header and value overflows.
    fn from_box(value: Box<I::T>) -> NonNull<Self> {
        // Header first, then the value at its aligned offset, padded to the
        // overall alignment; this mirrors the `#[repr(C)]` layout of `Self`.
        // NOTE(review): uses the header layout of `RefCounted<()>`, which
        // assumes `State<()>` and `State<I>` have identical size and alignment.
        let layout = Layout::new::<RefCounted<()>>()
            .extend(Layout::for_value(value.as_ref()))
            .unwrap()
            .0
            .pad_to_align();
        unsafe {
            let ptr = alloc(layout);
            // `alloc` signals failure with a null pointer; bail out before the
            // box is leaked or anything is written.
            if ptr.is_null() {
                std::alloc::handle_alloc_error(layout);
            }
            let b = Box::leak(value) as *mut I::T;
            // Build a `*mut Self` that carries the metadata of `b` (length or
            // vtable, if any) but points at the new allocation: cast first,
            // then overwrite only the address part of the pointer.
            let ptr = {
                let mut temp = b as *mut Self;
                std::ptr::write(&mut temp as *mut _ as *mut *mut u8, ptr);
                temp
            };
            std::ptr::write(&mut (*ptr).state, State::new());
            let num_bytes = std::mem::size_of_val(&*b);
            // A zero-sized value needs no copy, and its box owns no heap
            // memory, so there is nothing to free either.
            if num_bytes > 0 {
                std::ptr::copy_nonoverlapping(
                    b as *const u8,
                    &mut (*ptr).value as *mut _ as *mut u8,
                    num_bytes,
                );
                // Free the box's memory without dropping the value, which now
                // lives in the new block.
                #[cfg(not(loom))]
                dealloc(b as *mut u8, Layout::for_value(&*b));
                #[cfg(loom)]
                std::alloc::dealloc(b as *mut u8, Layout::for_value(&*b));
            }
            NonNull::new_unchecked(ptr)
        }
    }

    /// Allocates a `RefCounted` block for a sized value via `Box` and leaks
    /// it, returning the pointer to the block.
    fn from_sized(value: I::T) -> NonNull<Self>
    where
        I::T: Sized,
    {
        let b = Box::new(Self {
            state: State::new(),
            value,
        });
        NonNull::from(Box::leak(b))
    }
}
/// Reference-counted handle to a value of type `I::T`.
///
/// Cloning bumps the count stored next to the value; dropping decrements it
/// and frees the allocation once the count reaches zero. Dereferences to the
/// value, and compares equal by pointer identity.
pub struct Interned<I: Interner> {
// Pointer to the shared block (bookkeeping header plus value).
inner: NonNull<RefCounted<I>>,
}
// `Interned` wraps a raw `NonNull`, so it gets no automatic `Send`/`Sync`.
// These impls opt in when the value type is `Send + Sync + 'static`. Within
// this file the reference count and the cleanup slot are only accessed through
// a `Guard`, i.e. while the per-allocation mutex is held.
// NOTE(review): the shared state also holds a `Weak<I>` that `Drop` may upgrade
// and call `remove` on from whichever thread drops the handle, yet the bounds
// only mention `I::T` — confirm every sealed `I` is itself `Send + Sync`.
unsafe impl<I: Interner> Send for Interned<I> where I::T: Send + Sync + 'static {}
unsafe impl<I: Interner> Sync for Interned<I> where I::T: Send + Sync + 'static {}
impl<I: Interner> Interned<I> {
    /// Reads the current reference count while holding the state lock.
    pub(crate) fn ref_count(&self) -> u32 {
        let guard = self.lock();
        guard.refs()
    }

    /// Locks the bookkeeping state of the shared block.
    fn lock(&self) -> Guard<'_, I> {
        // SAFETY: `inner` points at the shared block, which is only freed by
        // `Drop` once the reference count has reached zero.
        let shared = unsafe { self.inner.as_ref() };
        shared.state.lock()
    }

    /// Wraps a boxed (possibly unsized) value in a new handle.
    pub(crate) fn from_box(value: Box<I::T>) -> Self {
        let inner = RefCounted::<I>::from_box(value);
        Self { inner }
    }

    /// Wraps a sized value in a new handle.
    pub(crate) fn from_sized(value: I::T) -> Self
    where
        I::T: Sized,
    {
        let inner = RefCounted::<I>::from_sized(value);
        Self { inner }
    }

    /// Registers `set` as the interner to notify when this value is dropped
    /// down to its last reference; only a weak link is stored.
    pub(crate) fn make_hot(&mut self, set: &Arc<I>) {
        let mut guard = self.lock();
        let link = Arc::downgrade(set);
        *guard.cleanup() = Some(link);
    }
}
/// Largest reference count an `Interned` may reach; `Clone::clone` panics
/// rather than going beyond it. Kept slightly below `u32::MAX` so the counter
/// stays clear of wrapping.
const MAX_REFCOUNT: u32 = u32::MAX - 2;
impl<I: Interner> Clone for Interned<I> {
    /// Creates another handle to the same value by incrementing the shared
    /// reference count under the state lock.
    ///
    /// The limit is checked *before* the count is touched, so a rejected
    /// clone leaves the count unchanged. This keeps the counter from creeping
    /// towards `u32` overflow when the panic is caught and cloning is retried.
    ///
    /// # Panics
    ///
    /// Panics if the clone would push the reference count above
    /// `MAX_REFCOUNT`.
    fn clone(&self) -> Self {
        let exhausted = {
            let mut state = self.lock();
            if state.refs() >= MAX_REFCOUNT {
                true
            } else {
                *state.refs_mut() += 1;
                false
            }
        };
        // Panic only after the guard is gone, so the lock is not held while
        // the panic machinery runs.
        if exhausted {
            panic!("either you are running on an 8086 or you are leaking Interned values at a phantastic rate");
        }
        let ret = Self { inner: self.inner };
        #[cfg(feature = "println")]
        println!("{:?} clone {:p}", current().id(), *self);
        ret
    }
}
// Releases one reference. Depending on the count that remains this either
// does nothing (more than one left), asks the registered interner to remove
// the value (exactly one left), or frees the allocation (none left).
impl<I: Interner> Drop for Interned<I> {
fn drop(&mut self) {
#[cfg(feature = "println")]
println!("{:?} dropping {:p} {:p}", current().id(), self, *self);
// The count and the cleanup slot are only touched while this guard is alive.
let mut state = self.lock();
#[cfg(feature = "println")]
println!(
"{:?} read {} {:p} {:p}",
current().id(),
state.refs(),
self,
*self
);
*state.refs_mut() -= 1;
// More than one handle left: nothing else to do (the guard unlocks on return).
if state.refs() > 1 {
return;
}
if state.refs() == 1 {
// Exactly one handle remains. If a cleanup target was registered (see
// `make_hot`), take it out of the slot and try to notify that interner.
// NOTE(review): presumably the remaining handle is the one the interner
// itself holds — confirm against the interner implementations.
if let Some(cleanup) = state.cleanup().take() {
#[cfg(feature = "println")]
println!("{:?} removing {:p} {:p}", current().id(), self, *self);
if let Some(strong) = cleanup.upgrade() {
// Unlock before calling into the interner; the code below re-locks when
// it needs to look at the count again.
drop(state);
// Retry until the interner reports the removal, or until the count is
// seen above one again.
loop {
// `_value` stays alive until the end of this iteration; dropping it runs
// this `Drop` impl again for that handle.
let (removed, _value) = strong.remove(self);
if removed {
break;
} else {
let mut state = self.lock();
if state.refs() > 1 {
// The count rose again in the meantime: put the weak link back so a
// later drop can retry, and stop.
*state.cleanup() = Some(cleanup);
break;
} else {
drop(state);
}
}
}
} else {
// The interner no longer exists; the weak link is simply discarded.
}
#[cfg(feature = "println")]
println!("{:?} removed {:p}", current().id(), self);
} else {
#[cfg(feature = "println")]
println!("{:?} cleanup gone {:p}", current().id(), self);
}
} else if state.refs() == 0 {
#[cfg(feature = "println")]
println!("{:?} drop {:p} {:p}", current().id(), self, *self);
// Last handle. Unlock first, because the mutex lives inside the block
// that is about to be freed; then rebuild the `Box` so the value is
// dropped and the memory released.
drop(state);
drop(unsafe { Box::from_raw(self.inner.as_ptr()) });
}
#[cfg(feature = "println")]
println!("{:?} dropend {:p}", current().id(), self);
}
}
/// Two handles are equal exactly when they point at the same shared block;
/// the values themselves are never compared.
impl<I: Interner> PartialEq for Interned<I> {
    fn eq(&self, other: &Self) -> bool {
        let (mine, theirs) = (self.inner, other.inner);
        mine == theirs
    }
}
/// Gives shared access to the value stored in the shared block.
impl<I: Interner> Deref for Interned<I> {
    type Target = I::T;

    fn deref(&self) -> &Self::Target {
        // SAFETY: `inner` points at the shared block, which is only freed by
        // `Drop` once the reference count has reached zero.
        let shared = unsafe { self.inner.as_ref() };
        &shared.value
    }
}
/// Formats the address of the stored value (not of the handle itself), so all
/// clones of one handle print the same pointer.
impl<I: Interner> Pointer for Interned<I> {
    fn fmt(&self, f: &mut Formatter<'_>) -> Result {
        let value = &**self;
        let address = value as *const I::T;
        Pointer::fmt(&address, f)
    }
}
#[cfg(all(test, not(loom)))]
mod tests {
    use super::*;
    use crate::OrdInterner;

    /// Byte distance of `field` from `base`; both must lie in the same object.
    fn distance<T>(base: *const u8, field: &T) -> isize {
        let start = field as *const T as *const u8;
        unsafe { start.offset_from(base) }
    }

    /// A clone must format to the same pointer as the handle it came from.
    #[test]
    fn pointer() {
        let interner = OrdInterner::new();
        let first = interner.intern_sized(42);
        let second = first.clone();
        assert_eq!(format!("{:p}", first), format!("{:p}", second));
    }

    /// Checks the header size and prints where each field sits in the block.
    #[test]
    fn size() {
        use std::mem::size_of;
        const SIZE: usize = if size_of::<usize>() == 4 { 12 } else { 16 };
        assert_eq!(size_of::<RefCounted<()>>(), SIZE);

        let fake = RefCounted::<crate::hash::Hash<i32>> {
            state: State::new(),
            value: 42,
        };
        println!("base: {:p}", &fake);
        let base = &fake as *const _ as *const u8;
        println!(
            "state: {:p} (base + {})",
            &fake.state,
            distance(base, &fake.state)
        );
        println!(
            "mutex: {:p} (base + {})",
            &fake.state.mutex,
            distance(base, &fake.state.mutex)
        );
        println!(
            "refs: {:p} (base + {})",
            &fake.state.refs,
            distance(base, &fake.state.refs)
        );
        println!(
            "clean: {:p} (base + {})",
            &fake.state.cleanup,
            distance(base, &fake.state.cleanup)
        );
        println!(
            "value: {:p} (base + {})",
            &fake.value,
            distance(base, &fake.value)
        );
    }
}