use std::cell::RefCell;
use std::ops::{Deref, DerefMut};
/// A registered cleanup closure, paired with a flag that says whether it should still run.
type CleanupEntry = (Box<dyn FnOnce() + 'static>, bool);

thread_local! {
    /// Per-thread stack of cleanups pushed by `BailoutGuard::new` and emptied by
    /// `run_bailout_cleanups`.
    static CLEANUP_STACK: RefCell<Vec<CleanupEntry>> =
        const { RefCell::new(Vec::<CleanupEntry>::new()) };
}
/// Owns a heap-allocated `T` whose release is also registered on this thread's cleanup stack.
///
/// Dropping the guard, or consuming it with [`BailoutGuard::into_inner`], deactivates the
/// registered cleanup, so the value is released exactly once through the guard. If the guard's
/// destructor never runs (for instance because it was passed to `std::mem::forget`), the
/// cleanup stays active and `run_bailout_cleanups` frees the value instead.
///
/// The type is intentionally neither `Send` nor `Sync`: its cleanup entry lives in a
/// thread-local stack, and dropping the guard on another thread would deactivate an unrelated
/// entry there while leaving its own entry active, so a later `run_bailout_cleanups` on the
/// creating thread would free the value a second time.
///
/// Caveat: calling `run_bailout_cleanups` while a guard is still alive frees that guard's value
/// out from under it. The `unsafe` blocks below assume that has not happened.
#[must_use = "a BailoutGuard that is dropped immediately frees its value right away"]
pub struct BailoutGuard<T> {
    // Allocation obtained from `Box::into_raw`; owned by this guard while it is alive.
    value: *mut T,
    // Position of this guard's entry in the creating thread's `CLEANUP_STACK`.
    index: usize,
}

/// Marks the cleanup entry at `index` inactive, then pops inactive entries off the top of the
/// stack so that guards released normally do not accumulate there.
///
/// Only trailing entries are removed, which keeps the indices held by live guards valid. The
/// stack borrow is released before this returns, so callers may then run arbitrary destructors
/// (including ones that create or drop other guards).
fn deactivate_cleanup(index: usize) {
    CLEANUP_STACK.with(|stack| {
        let mut stack = stack.borrow_mut();
        if let Some(entry) = stack.get_mut(index) {
            entry.1 = false;
        }
        // Dropping a popped closure without calling it does not touch the value it points to.
        while matches!(stack.last(), Some((_, false))) {
            stack.pop();
        }
    });
}

impl<T: 'static> BailoutGuard<T> {
    /// Moves `value` to the heap and registers an active cleanup that frees it.
    ///
    /// # Panics
    ///
    /// Panics if this thread's cleanup stack is currently borrowed.
    pub fn new(value: T) -> Self {
        let ptr = Box::into_raw(Box::new(value));
        let index = CLEANUP_STACK.with(|stack| {
            let mut stack = stack.borrow_mut();
            let cleanup: Box<dyn FnOnce()> = Box::new(move || {
                // SAFETY: this closure is only invoked while its entry is still active, i.e.
                // the guard was neither dropped nor consumed, so `ptr` is still the live
                // allocation from `Box::into_raw` above and nothing else frees it.
                drop(unsafe { Box::from_raw(ptr) });
            });
            stack.push((cleanup, true));
            stack.len() - 1
        });
        Self { value: ptr, index }
    }

    /// Returns a shared reference to the guarded value.
    #[inline]
    #[must_use]
    pub fn get(&self) -> &T {
        // SAFETY: `self.value` is a live allocation owned by this guard (see the type-level
        // caveat about `run_bailout_cleanups`).
        unsafe { &*self.value }
    }

    /// Returns a mutable reference to the guarded value.
    #[inline]
    pub fn get_mut(&mut self) -> &mut T {
        // SAFETY: as in `get`; `&mut self` makes this access unique.
        unsafe { &mut *self.value }
    }

    /// Consumes the guard, deactivates its cleanup, and returns the owned value.
    #[must_use]
    pub fn into_inner(self) -> T {
        let (ptr, index) = (self.value, self.index);
        // Ownership of the allocation moves into this function, so the guard's destructor
        // must not run. If deactivation below panics, the still-active entry keeps ownership.
        std::mem::forget(self);
        deactivate_cleanup(index);
        // SAFETY: `ptr` came from `Box::into_raw` in `new`; the guard was forgotten above and
        // its cleanup entry is now inactive, so this is the sole owner.
        let boxed = unsafe { Box::from_raw(ptr) };
        *boxed
    }
}

impl<T> Deref for BailoutGuard<T> {
    type Target = T;

    #[inline]
    fn deref(&self) -> &T {
        // SAFETY: same invariant as `BailoutGuard::get`.
        unsafe { &*self.value }
    }
}

impl<T> DerefMut for BailoutGuard<T> {
    #[inline]
    fn deref_mut(&mut self) -> &mut T {
        // SAFETY: same invariant as `BailoutGuard::get_mut`.
        unsafe { &mut *self.value }
    }
}

impl<T> Drop for BailoutGuard<T> {
    fn drop(&mut self) {
        // Deactivate (and release the stack borrow) before freeing, so `T`'s destructor may
        // itself create or drop guards.
        deactivate_cleanup(self.index);
        // SAFETY: `self.value` came from `Box::into_raw` in `new` and is still owned by this
        // guard: `into_inner` forgets the guard instead of dropping it, and the cleanup entry
        // that could also free it is now inactive.
        unsafe { drop(Box::from_raw(self.value)) };
    }
}
/// Runs every still-active cleanup on this thread's stack, newest first, and empties the stack.
///
/// Entries are popped one at a time, and the stack's `RefCell` borrow is released before each
/// cleanup runs. This matters because a cleanup drops a user value whose destructor may itself
/// create or drop a `BailoutGuard` (for example, a guard nested inside another guard's value),
/// which needs to borrow the stack again. Cleanups registered while this runs are run as well.
///
/// Only call this once no guard whose cleanup is still active will be used again (for example,
/// guards that were forgotten). Running it while such a guard is alive frees that guard's value
/// out from under it.
#[doc(hidden)]
pub fn run_bailout_cleanups() {
    // The `RefMut` lives only inside the closure, so no borrow is held while `cleanup` runs.
    while let Some((cleanup, active)) = CLEANUP_STACK.with(|stack| stack.borrow_mut().pop()) {
        if active {
            cleanup();
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    /// Bumps a shared counter when dropped, letting tests observe when a guarded value is freed.
    struct DropCounter(Arc<AtomicUsize>);

    impl Drop for DropCounter {
        fn drop(&mut self) {
            self.0.fetch_add(1, Ordering::SeqCst);
        }
    }

    fn make_drop_counter(counter: Arc<AtomicUsize>) -> DropCounter {
        DropCounter(counter)
    }

    /// Number of `DropCounter`s dropped so far against `counter`.
    fn drops_seen(counter: &AtomicUsize) -> usize {
        counter.load(Ordering::SeqCst)
    }

    /// Empties this thread's cleanup stack without running any of its entries.
    fn clear_cleanup_stack() {
        CLEANUP_STACK.with(|stack| stack.borrow_mut().clear());
    }

    #[test]
    fn test_normal_drop() {
        let counter = Arc::new(AtomicUsize::new(0));
        clear_cleanup_stack();
        {
            let _guard = BailoutGuard::new(make_drop_counter(Arc::clone(&counter)));
            assert_eq!(drops_seen(&counter), 0);
        }
        assert_eq!(drops_seen(&counter), 1);
        // The first entry is either gone or left behind marked inactive.
        let first_entry_inactive = CLEANUP_STACK
            .with(|stack| stack.borrow().first().map_or(true, |(_, active)| !*active));
        assert!(first_entry_inactive);
    }

    #[test]
    fn test_bailout_cleanup() {
        let counter = Arc::new(AtomicUsize::new(0));
        clear_cleanup_stack();
        // Skip the guard's destructor so only the registered cleanup can free the value.
        std::mem::forget(BailoutGuard::new(make_drop_counter(Arc::clone(&counter))));
        assert_eq!(drops_seen(&counter), 0);
        run_bailout_cleanups();
        assert_eq!(drops_seen(&counter), 1);
    }

    #[test]
    fn test_into_inner() {
        let counter = Arc::new(AtomicUsize::new(0));
        clear_cleanup_stack();
        let inner = BailoutGuard::new(make_drop_counter(Arc::clone(&counter))).into_inner();
        assert_eq!(drops_seen(&counter), 0);
        drop(inner);
        assert_eq!(drops_seen(&counter), 1);
    }

    #[test]
    fn test_deref() {
        let guard = BailoutGuard::new(String::from("hello"));
        assert_eq!(*guard, "hello");
        assert_eq!(guard.len(), 5);
    }

    #[test]
    fn test_deref_mut() {
        let mut guard = BailoutGuard::new(String::from("hello"));
        guard.push_str(" world");
        assert_eq!(*guard, "hello world");
    }
}