use core::{
fmt::Debug,
mem::MaybeUninit,
ops::{Deref, DerefMut},
ptr::{self, NonNull},
sync::atomic::{AtomicUsize, Ordering},
};
/// Reference-counted handle to a heap-allocated `T`.
///
/// Cloning a handle increments an atomic count; when the last handle is
/// dropped the count reaches zero and the value is dropped. Unlike `Arc`,
/// handles also implement `DerefMut`.
#[repr(transparent)]
pub struct Shared<T>(NonNull<Inner<T>>);
/// The heap allocation behind every `Shared<T>` / `Token<T>` handle.
///
/// `#[repr(C)]` fixes the field order, so `Inner<MaybeUninit<T>>` and
/// `Inner<T>` are guaranteed to have identical layouts. `Shared::new_cyclic`
/// relies on this: it allocates the former and then casts the pointer to the
/// latter. With the default representation the compiler is free to order the
/// fields differently for different `T`, which would make that cast unsound.
#[repr(C)]
struct Inner<T> {
    /// The stored value (still uninitialized while `Shared::new_cyclic`'s
    /// closure runs).
    data: T,
    /// Reference count adjusted by `Shared::clone`, `Shared::drop` and
    /// `Token::shared`.
    count: AtomicUsize,
}
/// Handle to an allocation whose value is still being built; it is passed to
/// the closure given to `Shared::new_cyclic`.
///
/// Call [`Token::shared`] to obtain `Shared` handles that point back at the
/// value under construction.
///
/// NOTE(review): `Token` carries no lifetime tying it to the `new_cyclic` call,
/// so nothing prevents a closure from stashing it and using it after every
/// `Shared` handle to the allocation is gone — confirm callers never do this.
#[repr(transparent)]
pub struct Token<T>(NonNull<Inner<T>>);
impl<T> Token<T> {
pub fn shared(&self) -> Shared<T> {
unsafe { self.0.as_ref().count.fetch_add(1, Ordering::Relaxed) };
Shared(self.0)
}
pub fn count(&self) -> usize {
unsafe { self.0.as_ref().count.load(Ordering::Relaxed) }
}
}
impl<T> Shared<T> {
pub fn new(data: T) -> Shared<T> {
Self(
Box::leak(Box::new(Inner {
data,
count: AtomicUsize::new(1),
}))
.into(),
)
}
pub fn new_cyclic<F>(data_fn: F) -> Self
where
F: FnOnce(Token<T>) -> T,
{
let uninit_ptr: NonNull<_> = Box::leak(Box::new(Inner {
data: MaybeUninit::<T>::uninit(),
count: AtomicUsize::new(0),
}))
.into();
let mut init_ptr = uninit_ptr.cast();
unsafe {
let data = data_fn(Token(init_ptr));
ptr::write(ptr::addr_of_mut!(init_ptr.as_mut().data), data);
init_ptr.as_mut().count.fetch_add(1, Ordering::Relaxed);
Self(init_ptr)
}
}
#[inline]
fn inner(&self) -> &Inner<T> {
unsafe { self.0.as_ref() }
}
#[inline]
fn inner_mut(&mut self) -> &mut Inner<T> {
unsafe { self.0.as_mut() }
}
}
impl<T> Clone for Shared<T> {
fn clone(&self) -> Self {
self.inner().count.fetch_add(1, Ordering::Relaxed);
Self(self.0)
}
}
impl<T> Debug for Shared<T> {
    /// Prints only the type name, so `T` is not required to implement `Debug`.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str("Shared")
    }
}
impl<T: Default> Default for Shared<T> {
fn default() -> Self {
Self::new(T::default())
}
}
impl<T> Deref for Shared<T> {
    type Target = T;

    /// Borrows the value stored in the shared allocation.
    fn deref(&self) -> &T {
        let inner = self.inner();
        &inner.data
    }
}
impl<T> DerefMut for Shared<T> {
    /// Mutably borrows the stored value.
    ///
    /// NOTE(review): `&mut self` only proves that this *handle* is unique, not
    /// the allocation. Clones (including handles made from a `Token`) reach the
    /// same `T` through their own handles, so the returned `&mut T` is only
    /// exclusive if no other handle touches the value while it is borrowed.
    fn deref_mut(&mut self) -> &mut T {
        let inner = self.inner_mut();
        &mut inner.data
    }
}
impl<T> Drop for Shared<T> {
    /// Releases this handle. The handle that brings the count to zero drops the
    /// value and frees the allocation.
    fn drop(&mut self) {
        // Use a shared borrow (`inner`): other handles may alias the allocation,
        // so a `&mut Inner<T>` here would not be unique.
        // The Release decrement publishes this handle's earlier accesses to
        // whichever thread ends up performing the final decrement.
        if self.inner().count.fetch_sub(1, Ordering::Release) != 1 {
            return;
        }
        // Synchronizes with the Release decrements of all other handles, so
        // their accesses happen-before the value is destroyed (the same scheme
        // `Arc` uses).
        core::sync::atomic::fence(Ordering::Acquire);
        // SAFETY: the count just went from 1 to 0, so this is the last handle
        // and nobody else can reach the allocation. The allocation was created
        // with `Box::new` + `Box::leak` in the constructors (for `new_cyclic`
        // this additionally assumes `Inner<MaybeUninit<T>>` and `Inner<T>` share
        // a layout). Rebuilding the `Box` therefore both runs `T`'s destructor
        // and returns the memory to the allocator; `drop_in_place` alone would
        // leak the allocation.
        unsafe { drop(Box::from_raw(self.0.as_ptr())) };
    }
}
// SAFETY: every handle can reach the same `T`, and whichever handle is last
// drops it, possibly on a different thread from the one that created it. `T`
// must therefore be both `Send` and `Sync`, which mirrors the bounds on
// `std::sync::Arc`. The reference count itself is atomic.
// NOTE(review): `DerefMut` hands out `&mut T` through a non-unique handle, so
// concurrent `deref_mut` calls on two handles are still a data race even with
// these bounds.
unsafe impl<T: Send + Sync> Send for Shared<T> {}
unsafe impl<T: Send + Sync> Sync for Shared<T> {}
#[cfg(test)]
mod tests {
    use super::*;

    /// Outer node of a self-referential structure.
    struct Foo {
        bar: Bar,
    }

    /// Inner node holding a payload plus a handle back to its owning `Foo`.
    struct Bar {
        data: Vec<u8>,
        foo: Shared<Foo>,
    }

    /// Builds a `Foo` whose `Bar` points back at it, mutates the payload through
    /// the root handle, and observes the change through the back-reference.
    #[test]
    fn example() {
        let mut root = Shared::new_cyclic(|token| {
            let back_ref = token.shared();
            Foo {
                bar: Bar {
                    data: vec![0, 1, 2],
                    foo: back_ref,
                },
            }
        });
        root.bar.data.push(4);
        let seen_through_cycle = &root.bar.foo.bar.data;
        assert_eq!(root.bar.data.len(), 4);
        assert_eq!(root.bar.data, *seen_through_cycle);
    }
}