use std::{fmt::Debug, marker::PhantomData};
/// Interior-mutability cell that lends out `&mut T` for the duration of a closure
/// passed to `use_mut`.
///
/// Unlike `RefCell` there is no borrow counter: the only runtime state besides the
/// value is the `recursion_check` field, which guards against a nested `use_mut` on
/// the same cell (which builds actually enforce it is decided by the
/// `recursion_check` module).
pub struct MutInPlaceCell<T> {
// The wrapped value; only ever reached through `use_mut`.
cell: std::cell::UnsafeCell<T>,
// Reentrancy detector entered by `use_mut` for as long as its closure runs.
recursion_check: recursion_check::RecursionCheck,
// `*const T` is neither `Send` nor `Sync`, so this marker strips both auto traits.
// `Send` is re-added below for `T: Send`; `Sync` is deliberately left off, so a
// shared `&MutInPlaceCell<T>` can never be used from two threads at once.
_marker: PhantomData<*const T>,
}
// SAFETY: sending the cell to another thread moves the owned `T` with it, which is
// sound exactly when `T: Send`. The remaining state (`RecursionCheck`, which in the
// definitions in this file holds at most a `Cell<bool>`) is itself `Send`. No `Sync`
// impl is given here, so the unchecked `&mut` in `use_mut` can never be produced
// concurrently from two threads.
unsafe impl<T: Send> Send for MutInPlaceCell<T> {}
impl<T: Default> Default for MutInPlaceCell<T> {
    /// Builds a cell whose contents are `T::default()`.
    fn default() -> Self {
        let initial = T::default();
        Self::new(initial)
    }
}
impl<T: Debug> Debug for MutInPlaceCell<T> {
    /// Formats the wrapped value exactly as `T`'s own `Debug` impl would, with no
    /// wrapper text and with the formatter's flags (e.g. `{:#?}`) passed through.
    ///
    /// Access goes through `use_mut`, so formatting the cell from inside one of its
    /// own `use_mut` closures counts as a reentrant access.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        self.use_mut(|value| <T as Debug>::fmt(value, f))
    }
}
impl<T> MutInPlaceCell<T> {
    /// Wraps `value` in a new cell.
    pub fn new(value: T) -> Self {
        let cell = std::cell::UnsafeCell::new(value);
        let recursion_check = recursion_check::RecursionCheck::new();
        Self {
            cell,
            recursion_check,
            _marker: PhantomData,
        }
    }

    /// Runs `f` with exclusive `&mut` access to the wrapped value and returns what `f`
    /// returns.
    ///
    /// The reentrancy guard lives until `f` has returned or unwound, so the cell is
    /// usable again afterwards even if `f` panicked.
    ///
    /// # Panics
    ///
    /// Panics with "failed recursion check" if `f` calls `use_mut` on this same cell,
    /// in builds where `recursion_check` enforces the check.
    #[inline(always)]
    pub fn use_mut<U>(&self, f: impl FnOnce(&mut T) -> U) -> U {
        let _reentry_guard = self.recursion_check.enter();
        // SAFETY: the cell is not `Sync`, so only the current thread can reach `cell`
        // through `&self`. The guard above rejects a nested `use_mut` on this cell while
        // `f` holds the reference, which is what makes this `&mut` unique. If `enter`
        // were compiled to a no-op, nothing would enforce that, and a reentrant call
        // from `f` would create aliasing `&mut T`.
        f(unsafe { &mut *self.cell.get() })
    }
}
#[cfg(debug_assertions)]
mod recursion_check {
    use std::cell::Cell;

    /// Reentrancy detector: a flag that is `true` while a caller sits between `enter`
    /// and the drop of the guard it received.
    pub struct RecursionCheck {
        inside: Cell<bool>,
    }

    impl RecursionCheck {
        /// Creates a detector in the "not entered" state.
        pub fn new() -> Self {
            RecursionCheck {
                inside: Cell::new(false),
            }
        }

        /// Marks the detector as entered and returns a guard that clears the mark when
        /// dropped.
        ///
        /// # Panics
        ///
        /// Panics with "failed recursion check" if the detector is already entered.
        pub fn enter(&self) -> RecursionCheckGuard<'_> {
            // Setting `true` when it was already `true` is a no-op, so swapping first and
            // checking afterwards leaves the same state as check-then-set.
            let was_inside = self.inside.replace(true);
            assert!(!was_inside, "failed recursion check");
            RecursionCheckGuard { owner: self }
        }
    }

    /// Guard returned by `RecursionCheck::enter`. Dropping it, including during
    /// unwinding, clears the entered flag.
    pub struct RecursionCheckGuard<'a> {
        owner: &'a RecursionCheck,
    }

    impl Drop for RecursionCheckGuard<'_> {
        fn drop(&mut self) {
            self.owner.inside.set(false);
        }
    }
}
#[cfg(not(debug_assertions))]
mod recursion_check {
    //! Release-build reentrancy detector.
    //!
    //! This must not be a no-op. `MutInPlaceCell::use_mut` is a safe function that
    //! produces `&mut T` from an `UnsafeCell`. Without this check, a reentrant call
    //! (`x.use_mut(|a| x.use_mut(|b| ..))`) would hand out two aliasing `&mut T`,
    //! which is undefined behavior reachable from safe code. The cost is one
    //! `Cell<bool>` per cell and a flag check per `use_mut` call. Behavior mirrors the
    //! `debug_assertions` variant of this module.

    /// Flag that is `true` while a caller sits between `enter` and the drop of the
    /// guard it received.
    pub struct RecursionCheck(std::cell::Cell<bool>);

    impl RecursionCheck {
        /// Creates a detector in the "not entered" state.
        pub fn new() -> Self {
            Self(std::cell::Cell::new(false))
        }

        /// Marks the detector as entered and returns a guard that clears the mark when
        /// dropped.
        ///
        /// # Panics
        ///
        /// Panics with "failed recursion check" if the detector is already entered.
        pub fn enter(&self) -> RecursionCheckGuard<'_> {
            assert!(!self.0.get(), "failed recursion check");
            self.0.set(true);
            RecursionCheckGuard {
                recursion_check: self,
            }
        }
    }

    /// Guard returned by `RecursionCheck::enter`. Dropping it, including during
    /// unwinding, clears the entered flag so the cell stays usable after a panic.
    pub struct RecursionCheckGuard<'a> {
        recursion_check: &'a RecursionCheck,
    }

    impl Drop for RecursionCheckGuard<'_> {
        fn drop(&mut self) {
            self.recursion_check.0.set(false);
        }
    }
}
#[cfg(test)]
mod test {
    use std::panic::AssertUnwindSafe;

    /// A nested `use_mut` on the same cell must be rejected by the reentrancy check.
    /// The message is pinned so that an unrelated panic cannot make this pass.
    #[test]
    #[should_panic(expected = "failed recursion check")]
    #[cfg(debug_assertions)]
    fn panic_on_recurse_test() {
        let x = super::MutInPlaceCell::new(0i32);
        x.use_mut(|_| x.use_mut(|_| ()));
    }

    /// A panic inside the closure must not leave the cell permanently "entered".
    #[test]
    fn use_after_panic_test() {
        let x = super::MutInPlaceCell::new(0i32);
        let result = std::panic::catch_unwind(AssertUnwindSafe(|| x.use_mut(|_| panic!("test"))));
        assert!(result.is_err());
        x.use_mut(|_| ());
    }

    /// Writes made through `use_mut` are visible to later calls, and the closure's
    /// return value is passed back out.
    #[test]
    fn mutation_persists_test() {
        let x = super::MutInPlaceCell::new(0i32);
        x.use_mut(|v| *v += 41);
        x.use_mut(|v| *v += 1);
        assert_eq!(x.use_mut(|v| *v), 42);
    }

    /// `Default` builds `T::default()`, and `Debug` forwards to `T` including the
    /// alternate (`{:#?}`) flag.
    #[test]
    fn default_and_debug_forward_test() {
        let x: super::MutInPlaceCell<Vec<u8>> = Default::default();
        assert_eq!(format!("{:?}", x), "[]");
        x.use_mut(|v| v.push(7));
        assert_eq!(format!("{:?}", x), "[7]");
        let y = super::MutInPlaceCell::new(Some(1));
        assert_eq!(format!("{:#?}", y), "Some(\n    1,\n)");
    }
}