use crate::ArcShift;
use core::cell::{Cell, UnsafeCell};
use core::fmt::{Debug, Display, Formatter};
use core::marker::PhantomData;
use core::ops::Deref;
/// A cell wrapping an [`ArcShift`] whose value can be read through `&self`.
///
/// The cell contains `Cell`/`UnsafeCell`, so it is not `Sync`: all access happens from one
/// thread at a time. Reads go through [`ArcShiftCell::borrow`] (an RAII handle) or
/// [`ArcShiftCell::get`] (a closure). Both track how many such accesses are nested so that
/// mutating operations (`assign`, `reload`) can refuse to run while a reference into the
/// cell may be alive.
pub struct ArcShiftCell<T: 'static + ?Sized> {
    // The wrapped pointer. Only accessed through the raw pointer from `UnsafeCell::get`;
    // whether a shared or unique borrow is taken depends on `recursion`.
    inner: UnsafeCell<ArcShift<T>>,
    // Number of currently active accesses (live handles plus in-progress `get` calls).
    // 0 means nothing currently borrows `inner`.
    recursion: Cell<usize>,
}
/// RAII handle returned by [`ArcShiftCell::borrow`] that dereferences to the cell's value.
///
/// While a handle is alive, the cell's access counter stays non-zero. During that time
/// `ArcShiftCell::assign` returns an error and `ArcShiftCell::reload` does nothing. Dropping
/// the outermost handle (the one that brings the counter back to zero) reloads the cell.
pub struct ArcShiftCellHandle<'a, T: 'static + ?Sized> {
    // The cell whose access counter this handle keeps incremented.
    cell: &'a ArcShiftCell<T>,
    // `*mut T` makes the handle neither `Send` nor `Sync`, so it cannot leave the thread
    // that borrowed the (non-`Sync`) cell.
    _marker: PhantomData<*mut T>,
}
impl<T: 'static + ?Sized> Drop for ArcShiftCellHandle<'_, T> {
    /// Releases this handle's access; the outermost handle also reloads the cell.
    fn drop(&mut self) {
        let counter = &self.cell.recursion;
        let remaining = counter.get() - 1;
        if remaining == 0 {
            // SAFETY: `remaining == 0` means this handle was the only active access, and
            // every `&T` obtained through it has ended (we hold `&mut self`). So no other
            // reference into `inner` exists and a unique borrow is exclusive.
            let shifted = unsafe { &mut *self.cell.inner.get() };
            shifted.reload();
        }
        // Lower the counter only after reloading. If the reload re-enters the cell (e.g. via
        // a destructor), the cell is still seen as borrowed and `assign` refuses to run.
        counter.set(remaining);
    }
}
impl<T: 'static + ?Sized> Deref for ArcShiftCellHandle<'_, T> {
type Target = T;
fn deref(&self) -> &Self::Target {
if self.cell.recursion.get() == 1 {
let inner: &mut ArcShift<T> = unsafe { &mut *self.cell.inner.get() };
inner.get()
} else {
let inner: &ArcShift<T> = unsafe { &*self.cell.inner.get() };
inner.shared_non_reloading_get()
}
}
}
// SAFETY: the cell is not `Sync`, so all access happens on one thread at a time. It cannot
// be moved while a handle or `get` call borrows it, so sending it only transfers ownership
// of the wrapped `ArcShift<T>`. Requiring `T: Send + Sync` presumably mirrors the bounds
// under which `ArcShift<T>` is itself `Send` (the value may be shared with other
// `ArcShift` instances on other threads). NOTE(review): confirm against ArcShift's impls.
unsafe impl<T> Send for ArcShiftCell<T> where T: 'static + Send + Sync {}
impl<T: 'static + ?Sized> Clone for ArcShiftCell<T> {
    /// Creates an independent cell holding a clone of this cell's `ArcShift`.
    ///
    /// The new cell starts with no active accesses, whatever the state of `self`.
    fn clone(&self) -> Self {
        // SAFETY: cloning an `ArcShift` only needs `&ArcShift<T>`. A shared borrow is used
        // (rather than `&mut`) because `clone` may be called from inside `get` or while a
        // handle is alive, when other borrows of `inner` can exist.
        let cloned = unsafe { &*self.inner.get() }.clone();
        ArcShiftCell {
            inner: UnsafeCell::new(cloned),
            recursion: Cell::new(0),
        }
    }
}
/// Error returned by `ArcShiftCell::assign` when the cell is currently being accessed
/// (a handle from `borrow` is alive, or a `get` call is in progress).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct RecursionDetected;

impl Display for RecursionDetected {
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
        f.write_str("RecursionDetected")
    }
}
impl<T: 'static> ArcShiftCell<T> {
    /// Creates a cell around `ArcShift::new(value)`.
    pub fn new(value: T) -> ArcShiftCell<T> {
        let shared = ArcShift::new(value);
        Self::from_arcshift(shared)
    }
}
impl<T: 'static + ?Sized> ArcShiftCell<T> {
    /// Wraps an existing `ArcShift` in a cell with no active accesses.
    pub fn from_arcshift(input: ArcShift<T>) -> ArcShiftCell<T> {
        ArcShiftCell {
            inner: UnsafeCell::new(input),
            recursion: Cell::new(0),
        }
    }

    /// Returns a handle that dereferences to the cell's current value.
    ///
    /// If no other handle or `get` call is active, the cell is reloaded first. The handle
    /// then observes the newest value, and that value stays stable while the handle lives.
    /// While any handle is alive, `assign` returns [`RecursionDetected`] and `reload` does
    /// nothing. Dropping the outermost handle reloads the cell again.
    #[inline]
    pub fn borrow(&self) -> ArcShiftCellHandle<'_, T> {
        let depth = self.recursion.get() + 1;
        self.recursion.set(depth);
        // Build the handle before reloading so the counter is released even if the
        // reload unwinds.
        let handle = ArcShiftCellHandle {
            cell: self,
            _marker: PhantomData,
        };
        if depth == 1 {
            // SAFETY: the counter was 0 on entry, so no handle or `get` call was active and
            // no reference into `inner` can be alive; a unique borrow is exclusive. The
            // counter is already 1, so re-entrant `assign`/`reload` calls refuse to run.
            unsafe { &mut *self.inner.get() }.reload();
        }
        handle
    }

    /// Calls `f` with a reference to the cell's value and returns its result.
    ///
    /// The outermost access obtains the value via `ArcShift::get` (which takes `&mut` and
    /// may reload). Nested accesses use the already-loaded value. The access counter is
    /// restored even if `f` panics, so the cell stays usable after an unwind.
    #[inline]
    pub fn get<R>(&self, f: impl FnOnce(&T) -> R) -> R {
        /// Decrements the access counter when dropped, including during unwinding.
        struct Release<'c>(&'c Cell<usize>);
        impl Drop for Release<'_> {
            fn drop(&mut self) {
                self.0.set(self.0.get() - 1);
            }
        }

        let depth = self.recursion.get() + 1;
        self.recursion.set(depth);
        let _release = Release(&self.recursion);
        let value: &T = if depth == 1 {
            // SAFETY: depth 1 means no other handle or `get` call is active, so nothing
            // else references `inner` and a unique borrow is exclusive.
            unsafe { &mut *self.inner.get() }.get()
        } else {
            // SAFETY: depth > 1 means an outer access is active. Every access at that
            // point is shared (assign/reload refuse to run), so a shared borrow is sound.
            unsafe { &*self.inner.get() }.shared_non_reloading_get()
        };
        f(value)
    }

    /// Replaces the cell's `ArcShift` with a clone of `other`.
    ///
    /// # Errors
    ///
    /// Returns [`RecursionDetected`] if a handle from `borrow` is alive or a `get` call is in
    /// progress. The cell is left unchanged in that case.
    pub fn assign(&self, other: &ArcShift<T>) -> Result<(), RecursionDetected> {
        if self.recursion.get() != 0 {
            return Err(RecursionDetected);
        }
        let replacement = other.clone();
        // SAFETY: the counter is 0, so no reference into `inner` is alive. The unique borrow
        // ends when `replace` returns.
        let previous = core::mem::replace(unsafe { &mut *self.inner.get() }, replacement);
        // Drop the old `ArcShift` only after the unique borrow has ended. Its destructor may
        // run arbitrary code that re-enters this cell.
        drop(previous);
        Ok(())
    }

    /// Reloads the cell to the newest value, unless the cell is currently being accessed,
    /// in which case this does nothing.
    pub fn reload(&self) {
        if self.recursion.get() == 0 {
            // Mark the cell as accessed for the duration of the reload (as the handle's
            // `Drop` does), so re-entrant `assign`/`reload` calls refuse to run.
            self.recursion.set(1);
            // SAFETY: the counter was 0, so no reference into `inner` is alive.
            unsafe { &mut *self.inner.get() }.reload();
            self.recursion.set(0);
        }
    }

    /// Returns a new `ArcShift` cloned from the one held by this cell.
    pub fn make_arcshift(&self) -> ArcShift<T> {
        // SAFETY: cloning only needs `&ArcShift<T>`. This may be called while other
        // accesses are active, so take a shared (not unique) borrow.
        unsafe { &*self.inner.get() }.clone()
    }
}
impl<T: ?Sized + 'static> Debug for ArcShiftCell<T>
where
    T: Debug,
{
    /// Formats as `ArcShiftCell(<value>)`.
    ///
    /// Unlike a nested `write!`, this passes the caller's formatter flags (e.g. `{:#?}`,
    /// `{:x?}`) on to the wrapped value.
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
        let handle = self.borrow();
        let value: &T = &handle;
        // `&value` is `&&T`: `&T` is sized, so it coerces to `&dyn Debug` even when
        // `T` is unsized.
        f.debug_tuple("ArcShiftCell").field(&value).finish()
    }
}