#![doc = include_str!("../README.md")]
use std::{cell::Cell, marker::PhantomPinned, pin::Pin, ptr::NonNull};
/// Owning half of a guarded value pair.
///
/// Stores the value in a [`Cell`] and, once linked through
/// [`ValueGuard::registration`], a back-pointer to the [`RefGuard`] observing
/// it. Linking requires both halves to be pinned, and dropping either half
/// clears the other's pointer to it.
pub struct ValueGuard<T> {
    /// Back-pointer to the linked `RefGuard`, or `None` when unlinked.
    ref_guard: Cell<Option<NonNull<RefGuard<T>>>>,
    /// The guarded value.
    data: Cell<T>,
    /// Makes the type `!Unpin` so a pinned guard's address stays stable.
    _marker: PhantomPinned,
}
impl<T> ValueGuard<T> {
    /// Creates an unlinked guard owning `data`.
    #[inline]
    pub fn new(data: T) -> Self {
        let data = Cell::new(data);
        Self {
            ref_guard: Cell::new(None),
            data,
            _marker: PhantomPinned,
        }
    }

    /// Begins linking this pinned guard to a [`RefGuard`].
    ///
    /// Nothing is linked until [`GuardRegistration::register`] is called on
    /// the returned value.
    #[inline]
    pub fn registration<'a>(self: Pin<&'a Self>) -> GuardRegistration<'a, T> {
        let value_guard = self;
        GuardRegistration { value_guard }
    }

    /// Stores `value`, dropping the previously held value.
    #[inline]
    pub fn set(&self, value: T) {
        drop(self.data.replace(value));
    }
}
/// Clears the `RefGuard` back-pointer held by the `ValueGuard` at `guard`.
#[inline]
fn invalidate_value_guard<T>(guard: NonNull<ValueGuard<T>>) {
    // SAFETY: `guard` is the pointer stored in a linked `RefGuard`. The
    // `ValueGuard` it names was pinned when linked, and dropping it clears that
    // `RefGuard`'s pointer first, so the pointee is still alive here. Only a
    // shared reference is created, and it is used only to write a `Cell`.
    let owner = unsafe { guard.as_ref() };
    owner.ref_guard.set(None);
}
impl<T: Copy> ValueGuard<T> {
#[inline]
pub fn get(&self) -> T {
self.data.get()
}
}
impl<T> Drop for ValueGuard<T> {
    /// Unlinks the observing `RefGuard`, if any, so it stops pointing here.
    #[inline]
    fn drop(&mut self) {
        let Some(observer) = self.ref_guard.get() else {
            return;
        };
        invalidate_ref_guard(observer);
    }
}
/// Observing half of a guarded value pair.
///
/// While linked to a [`ValueGuard`], [`RefGuard::get`] returns that guard's
/// value. Once the `ValueGuard` is dropped, or is registered with a different
/// `RefGuard`, this guard reads `None`.
pub struct RefGuard<T> {
    /// Pointer to the linked `ValueGuard`, or `None` when unlinked.
    value_guard: Cell<Option<NonNull<ValueGuard<T>>>>,
    /// Makes the type `!Unpin` so a pinned guard's address stays stable.
    _marker: PhantomPinned,
}
impl<T> RefGuard<T> {
    /// Creates an unlinked observer; [`RefGuard::get`] yields `None` until it
    /// is registered with a [`ValueGuard`].
    #[inline]
    pub fn new() -> Self {
        let value_guard = Cell::new(None);
        Self {
            value_guard,
            _marker: PhantomPinned,
        }
    }
}
/// Clears the `ValueGuard` pointer held by the `RefGuard` at `guard`.
#[inline]
fn invalidate_ref_guard<T>(guard: NonNull<RefGuard<T>>) {
    // SAFETY: `guard` is the back-pointer stored in a linked `ValueGuard`. The
    // `RefGuard` it names was pinned when linked, and dropping it clears that
    // back-pointer first, so the pointee is still alive here. Only a shared
    // reference is created, and it is used only to write a `Cell`.
    let observer = unsafe { guard.as_ref() };
    observer.value_guard.set(None);
}
impl<T> RefGuard<T>
where
    T: Copy,
{
    /// Returns a copy of the linked `ValueGuard`'s value, or `None` when this
    /// guard is not linked.
    #[inline]
    pub fn get(&self) -> Option<T> {
        let target = self.value_guard.get()?;
        // SAFETY: a stored pointer always names a live, pinned `ValueGuard`;
        // that guard clears this pointer in its `Drop` before it goes away.
        let owner = unsafe { target.as_ref() };
        Some(owner.get())
    }
}
impl<T> Drop for RefGuard<T> {
    /// Unlinks the observed `ValueGuard`, if any, so it stops pointing here.
    #[inline]
    fn drop(&mut self) {
        let Some(owner) = self.value_guard.get() else {
            return;
        };
        invalidate_value_guard(owner);
    }
}
impl<T> Default for RefGuard<T> {
#[inline]
fn default() -> Self {
Self::new()
}
}
/// A pending link between a pinned [`ValueGuard`] and a [`RefGuard`].
///
/// Created by [`ValueGuard::registration`]. Nothing is linked until
/// [`GuardRegistration::register`] is called, so dropping this value unused
/// is almost certainly a mistake.
#[must_use = "a registration links nothing until `register` is called"]
pub struct GuardRegistration<'a, T> {
    /// The pinned value guard that `register` will link to a slot.
    value_guard: Pin<&'a ValueGuard<T>>,
}
impl<'a, T> GuardRegistration<'a, T> {
pub fn register(self, slot: Pin<&'a RefGuard<T>>) {
if let Some(old_guard) = slot
.value_guard
.replace(Some(self.value_guard.get_ref().into()))
{
invalidate_value_guard(old_guard);
}
if let Some(old_guard) = self
.value_guard
.ref_guard
.replace(Some(slot.get_ref().into()))
{
invalidate_ref_guard(old_guard);
}
}
}
#[cfg(test)]
mod test {
    use std::{mem, pin};

    use super::*;

    /// A linked observer tracks updates and reads `None` once the owner drops.
    #[test]
    fn basic() {
        let weak = pin::pin!(RefGuard::new());
        {
            let strong = pin::pin!(ValueGuard::new(2));
            strong.as_ref().registration().register(weak.as_ref());
            assert_eq!(strong.get(), 2);
            assert_eq!(weak.get(), Some(2));
            strong.as_ref().set(3);
            assert_eq!(strong.get(), 3);
            assert_eq!(weak.get(), Some(3));
        }
        assert_eq!(weak.get(), None);
    }

    /// Registering a second observer detaches the first one.
    #[test]
    fn multiple_registrations() {
        let weak1 = pin::pin!(RefGuard::new());
        let weak2 = pin::pin!(RefGuard::new());
        {
            let strong = pin::pin!(ValueGuard::new(2));
            strong.as_ref().registration().register(weak1.as_ref());
            assert_eq!(strong.get(), 2);
            assert_eq!(weak1.get(), Some(2));
            strong.as_ref().set(3);
            assert_eq!(strong.get(), 3);
            assert_eq!(weak1.get(), Some(3));
            strong.as_ref().registration().register(weak2.as_ref());
            assert_eq!(weak1.get(), None);
            assert_eq!(weak1.value_guard.get(), None);
            assert_eq!(strong.get(), 3);
            assert_eq!(weak2.get(), Some(3));
            strong.as_ref().set(4);
            assert_eq!(strong.get(), 4);
            assert_eq!(weak2.get(), Some(4));
        }
        assert_eq!(weak1.get(), None);
        assert_eq!(weak2.get(), None);
    }

    /// Moving an observer to a new owner unlinks the previous owner
    /// (the slot-side branch of `register`).
    #[test]
    fn slot_moves_to_new_value_guard() {
        let weak = pin::pin!(RefGuard::new());
        let first = pin::pin!(ValueGuard::new(1));
        let second = pin::pin!(ValueGuard::new(2));
        first.as_ref().registration().register(weak.as_ref());
        assert_eq!(weak.get(), Some(1));
        second.as_ref().registration().register(weak.as_ref());
        assert_eq!(weak.get(), Some(2));
        assert_eq!(first.ref_guard.get(), None);
        assert!(second.ref_guard.get().is_some());
    }

    /// Dropping the observer first clears the owner's back-pointer.
    #[test]
    fn ref_guard_dropped_first() {
        let strong = pin::pin!(ValueGuard::new(5));
        {
            let weak = pin::pin!(RefGuard::new());
            strong.as_ref().registration().register(weak.as_ref());
            assert_eq!(weak.get(), Some(5));
        }
        assert_eq!(strong.ref_guard.get(), None);
        assert_eq!(strong.get(), 5);
    }

    /// Registering an already-linked pair again leaves it linked.
    #[test]
    fn reregister_same_pair() {
        let weak = pin::pin!(RefGuard::new());
        let strong = pin::pin!(ValueGuard::new(7));
        strong.as_ref().registration().register(weak.as_ref());
        strong.as_ref().registration().register(weak.as_ref());
        assert_eq!(weak.get(), Some(7));
        assert!(strong.ref_guard.get().is_some());
    }

    /// Leaking a pinned owner keeps its memory alive, so the observer still
    /// reads it. Presumably ignored under Miri because the box is
    /// deliberately leaked.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn safe_leak() {
        let strong = Box::pin(ValueGuard::new(10));
        let weak = pin::pin!(RefGuard::new());
        strong.as_ref().registration().register(weak.as_ref());
        mem::forget(strong);
        assert_eq!(weak.get(), Some(10));
    }
}