use std::cell::UnsafeCell;
use std::fmt::{self, Debug, Formatter};
use std::mem::{self, MaybeUninit};
use std::ptr;
use std::sync::atomic::{self, AtomicBool, AtomicUsize, Ordering};
use std::thread::{self, Thread};
use scopeguard::defer_on_unwind;
#[cfg(test)]
mod tests;
/// State word of a cell that holds no value, has no initializer running and no parked waiters.
const STATE_CLEAR: usize = 0;
/// State word of a fully initialized cell. Every bit is set, so `BUSY_BIT` reads as set too.
const STATE_INIT: usize = usize::MAX;
/// Lowest bit of the state word, set by whoever claims the right to initialize the cell.
/// The remaining bits carry the address of the most recently parked waiter, or zero.
const BUSY_BIT: usize = 1;
/// Node of the intrusive, singly linked list of threads parked on a cell's state word.
///
/// Nodes are created as locals inside `Parked::park`. The alignment of two keeps the lowest
/// address bit clear so that a node address can share the state word with `BUSY_BIT`.
#[repr(align(2))]
struct Parked {
/// Set to `true` by the waker; the parked thread keeps parking until it observes the flag.
wake: AtomicBool,
/// Handle of the thread that owns this node, used to unpark it.
thread: Thread,
/// Node that headed the list when this node was pushed, or null if the list was empty.
next: *const Parked,
}
impl Parked {
/// Pushes a node for the current thread onto the waiter list stored in `state` and blocks
/// until that node is woken. Returns immediately, without publishing the node, if `state`
/// is observed to be `STATE_INIT`.
///
/// `guess` is the caller's last observed value of `state`; it seeds the compare-exchange loop
/// and is refreshed from the actual value on every failed attempt.
///
/// # Safety
///
/// NOTE(review): the contract is not written down in this file. Presumably `state` must be
/// the state word of a `WaitCell` and `guess` must not be `STATE_INIT` (the only visible
/// caller, `WaitCell::wait`, checks this) — confirm.
pub unsafe fn park(state: &AtomicUsize, mut guess: usize) {
let mut parked = Self {
wake: AtomicBool::new(false),
thread: thread::current(),
next: ptr::null(),
};
// The node address is stored in the state word, so its lowest bit must be free for BUSY_BIT.
let parked_ptr = &parked as *const Self as usize;
debug_assert!((parked_ptr & BUSY_BIT) == 0);
loop {
// Link to the current list head (the state word with the busy bit masked off).
parked.next = (guess & !BUSY_BIT) as *const Parked;
match state.compare_exchange_weak(
guess,
// Become the new head while preserving whatever the busy bit currently is.
parked_ptr | (guess & BUSY_BIT),
// Release publishes the node's fields to the thread that later takes the list.
Ordering::Release,
Ordering::Relaxed,
) {
Ok(..) => {
break;
}
Err(STATE_INIT) => {
// The cell was initialized in the meantime: nothing to wait for. The failed
// exchange was only Relaxed, so an Acquire fence is issued before returning.
atomic::fence(Ordering::Acquire);
return;
}
Err(current) => {
// Lost a race (or failed spuriously); retry against the value actually stored.
guess = current;
}
};
}
// `thread::park` may return spuriously, so only the `wake` flag decides when to stop.
// The node is a local of this function, so returning before `wake` is set would leave a
// dangling node reachable from `state`.
loop {
thread::park();
if parked.wake.load(Ordering::Acquire) {
return;
}
}
}
/// Wakes every node of the list starting at `head`, following the `next` links until null.
///
/// # Safety
///
/// Dereferences `head` and every node reachable from it, so all of them must still be alive.
/// NOTE(review): presumably each node may be woken only once (`unpark` asserts this) — confirm.
pub unsafe fn unpark_all(mut head: *const Self) {
while !head.is_null() {
head = Self::unpark(head);
}
}
/// Wakes the single node `parked` and returns its `next` pointer.
///
/// # Safety
///
/// `parked` must point to a live node.
///
/// # Panics
///
/// Panics if the node's `wake` flag was already set.
unsafe fn unpark(parked: *const Self) -> *const Self {
// Copy everything needed out of the node BEFORE setting `wake`: once the flag is set the
// owning thread may return from `park`, and the node is a local of that call.
let thread = (*parked).thread.clone();
let next = (*parked).next;
assert!(!(*parked).wake.swap(true, Ordering::Release));
thread.unpark();
next
}
}
// SAFETY: a node is reached from the waking thread only inside `Parked::unpark`, which reads
// `thread` and `next` and swaps the atomic `wake` flag.
// NOTE(review): the visible code shares nodes exclusively through raw pointers, which do not
// need this impl — confirm whether it is still required.
unsafe impl Sync for Parked {}
/// A cell that can be initialized through a shared reference and on which other threads can
/// block until the value becomes available.
pub struct WaitCell<T> {
/// Either `STATE_CLEAR`, `STATE_INIT`, or the address of the newest parked waiter (possibly
/// zero) combined with `BUSY_BIT`.
state: AtomicUsize,
/// Storage for the value; treated as holding a valid `T` exactly when `state` is `STATE_INIT`.
value: UnsafeCell<MaybeUninit<T>>,
}
impl<T> WaitCell<T> {
    /// Creates an empty cell.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            state: AtomicUsize::new(STATE_CLEAR),
            value: UnsafeCell::new(MaybeUninit::uninit()),
        }
    }

    /// Creates a cell that already holds `value`.
    #[must_use]
    pub const fn initialized(value: T) -> Self {
        Self {
            state: AtomicUsize::new(STATE_INIT),
            value: UnsafeCell::new(MaybeUninit::new(value)),
        }
    }

    /// Returns a raw pointer to the value slot, which may be uninitialized.
    ///
    /// # Safety
    ///
    /// Obtaining the pointer is harmless; it is the caller's job to make sure any access
    /// through it respects the cell's state and Rust's aliasing rules.
    #[must_use]
    unsafe fn get_ptr(&self) -> *mut T {
        // Stay in raw-pointer land: materialising a `&mut MaybeUninit<T>` from `&self` would
        // alias with references handed out to other threads. `MaybeUninit<T>` has the same
        // layout as `T`, so the cast is valid.
        self.value.get().cast::<T>()
    }

    /// Returns a shared reference to the stored value.
    ///
    /// # Safety
    ///
    /// The slot must be initialized and must not be mutated for the lifetime of the reference.
    #[must_use]
    unsafe fn get_ref(&self) -> &T {
        &*self.get_ptr()
    }

    /// Returns an exclusive reference to the stored value.
    ///
    /// # Safety
    ///
    /// The slot must be initialized and the caller must have exclusive access to the cell.
    // Every caller holds `&mut self`; the `&self` receiver only exists to share `get_ptr`.
    #[allow(clippy::mut_from_ref)]
    #[must_use]
    unsafe fn get_mut(&self) -> &mut T {
        &mut *self.get_ptr()
    }

    /// Stores `value`, marks the cell initialized and wakes every parked waiter.
    ///
    /// # Safety
    ///
    /// The caller must be the thread that set `BUSY_BIT` on a not-yet-initialized cell.
    unsafe fn wake(&self, value: T) -> &T {
        ptr::write(self.get_ptr(), value);
        // Release publishes the value written above. Acquire pairs with the Release
        // compare-exchange in `Parked::park`, so the waiter nodes read by `unpark_all` are
        // visible to this thread; a Release-only swap would make those reads a data race.
        let state = self.state.swap(STATE_INIT, Ordering::AcqRel);
        debug_assert!((state & BUSY_BIT) != 0);
        Parked::unpark_all((state & !BUSY_BIT) as *const Parked);
        self.get_ref()
    }

    /// Like [`wake`](Self::wake), but computes the value with `func`.
    ///
    /// If `func` unwinds, `BUSY_BIT` is cleared again and the panic propagates; already
    /// parked waiters stay parked.
    ///
    /// # Safety
    ///
    /// Same contract as [`wake`](Self::wake).
    unsafe fn wake_with<F: FnOnce() -> T>(&self, func: F) -> &T {
        self.wake({
            defer_on_unwind! {
                self.state.fetch_and(!BUSY_BIT, Ordering::Relaxed);
            }
            func()
        })
    }

    /// Blocks until the cell is initialized, then returns the value.
    ///
    /// `state` is the caller's latest observation of the state word.
    ///
    /// # Safety
    ///
    /// If `state` is `STATE_INIT` it must have been observed with acquire semantics, because
    /// the value is then read without further synchronization.
    unsafe fn wait(&self, state: usize) -> &T {
        if state != STATE_INIT {
            Parked::park(&self.state, state);
        }
        self.get_ref()
    }

    /// Initializes the cell with `value` and wakes all waiters.
    ///
    /// # Panics
    ///
    /// Panics if the cell is already initialized or another initialization is in progress.
    pub fn init(&self, value: T) -> &T {
        let state = self.state.fetch_or(BUSY_BIT, Ordering::Relaxed);
        assert!((state & BUSY_BIT) == 0, "WaitCell is not uninitialized");
        // SAFETY: the busy bit was clear, so this thread just claimed the initialization.
        unsafe { self.wake(value) }
    }

    /// Initializes the cell with the result of `func` unless it is already initialized or
    /// being initialized.
    ///
    /// Returns `true` if this call performed the initialization; otherwise `func` is not run.
    pub fn try_init<F: FnOnce() -> T>(&self, func: F) -> bool {
        let state = self.state.fetch_or(BUSY_BIT, Ordering::Relaxed);
        if (state & BUSY_BIT) != 0 {
            return false;
        }
        // SAFETY: the busy bit was clear, so this thread just claimed the initialization.
        unsafe {
            self.wake_with(func);
        }
        true
    }

    /// Returns the value, initializing it with `func` if nobody has claimed the cell yet;
    /// otherwise blocks until the value is available.
    #[must_use]
    pub fn get_or_init<F: FnOnce() -> T>(&self, func: F) -> &T {
        let state = self.state.fetch_or(BUSY_BIT, Ordering::Acquire);
        if (state & BUSY_BIT) == 0 {
            // SAFETY: the busy bit was clear, so this thread just claimed the initialization.
            unsafe { self.wake_with(func) }
        } else {
            // SAFETY: `state` was read with Acquire ordering.
            unsafe { self.wait(state) }
        }
    }

    /// Blocks the current thread until the cell is initialized and returns the value.
    #[must_use]
    pub fn get(&self) -> &T {
        // SAFETY: the state is read with Acquire ordering.
        unsafe { self.wait(self.state.load(Ordering::Acquire)) }
    }

    /// Returns the value if the cell is initialized, without blocking.
    #[must_use]
    pub fn try_get(&self) -> Option<&T> {
        // SAFETY: the closure only runs after `STATE_INIT` was observed with Acquire ordering.
        (self.state.load(Ordering::Acquire) == STATE_INIT).then(|| unsafe { self.get_ref() })
    }

    /// Stores `value`, dropping any previous value, and returns a reference to it.
    pub fn set(&mut self, value: T) -> &mut T {
        // SAFETY: `&mut self` gives exclusive access and `is_set` tells whether the slot
        // currently holds a value that the assignment must drop.
        unsafe {
            if self.is_set() {
                *self.get_mut() = value;
            } else {
                ptr::write(self.get_ptr(), value);
                *self.state.get_mut() = STATE_INIT;
            }
            self.get_mut()
        }
    }

    /// Drops the stored value, if any. Returns `true` if there was a value.
    pub fn unset(&mut self) -> bool {
        self.is_set()
            // SAFETY: the slot is initialized. The state is cleared first so that a panicking
            // destructor cannot lead to a second drop.
            .then(|| unsafe {
                *self.state.get_mut() = STATE_CLEAR;
                ptr::drop_in_place(self.get_ptr());
            })
            .is_some()
    }

    /// Returns `true` if the cell holds a value.
    #[must_use]
    pub fn is_set(&mut self) -> bool {
        let state = *self.state.get_mut();
        debug_assert!(state == STATE_CLEAR || state == STATE_INIT);
        state == STATE_INIT
    }

    /// Returns an exclusive reference to the value, or `None` if the cell is empty.
    #[must_use]
    pub fn as_inner(&mut self) -> Option<&mut T> {
        // SAFETY: the slot is initialized and `&mut self` gives exclusive access.
        self.is_set().then(move || unsafe { self.get_mut() })
    }

    /// Consumes the cell and returns the value, or `None` if the cell is empty.
    #[must_use]
    pub fn into_inner(mut self) -> Option<T> {
        // SAFETY: the slot is initialized; forgetting `self` keeps `Drop` from dropping the
        // value a second time.
        self.is_set().then(|| unsafe {
            let result = ptr::read(self.get_ptr());
            mem::forget(self);
            result
        })
    }

    /// Returns a raw pointer to the value slot, which may be uninitialized.
    #[must_use]
    pub fn as_ptr(&self) -> *const T {
        // SAFETY: only the pointer is produced; nothing is dereferenced.
        unsafe { self.get_ptr() }
    }

    /// Returns a mutable raw pointer to the value slot, which may be uninitialized.
    #[must_use]
    pub fn as_mut_ptr(&mut self) -> *mut T {
        // SAFETY: only the pointer is produced; nothing is dereferenced.
        unsafe { self.get_ptr() }
    }

    /// Marks the cell as initialized without writing a value.
    ///
    /// # Safety
    ///
    /// The slot must already contain a valid `T`, for example one written through
    /// [`as_mut_ptr`](Self::as_mut_ptr); the cell reads and eventually drops it.
    pub unsafe fn set_init(&mut self) {
        *self.state.get_mut() = STATE_INIT;
    }

    /// Marks the cell as empty without dropping a value it may hold (the value is leaked).
    pub fn set_uninit(&mut self) {
        *self.state.get_mut() = STATE_CLEAR;
    }
}
impl<T: Default> WaitCell<T> {
    /// Creates a cell that is already initialized with `T`'s default value.
    #[must_use]
    pub fn with_default() -> Self {
        let value = T::default();
        Self::initialized(value)
    }
}
impl<T> Drop for WaitCell<T> {
    /// Drops the stored value if the cell is initialized.
    fn drop(&mut self) {
        let bits = *self.state.get_mut();
        if bits != STATE_INIT {
            // Nothing to drop; in debug builds check that no waiter is still registered.
            debug_assert_eq!(bits & !BUSY_BIT, STATE_CLEAR);
            return;
        }
        // The state says the slot holds a value, and `&mut self` means nobody else can see it.
        unsafe {
            ptr::drop_in_place(self.get_ptr());
        }
    }
}
impl<T> Debug for WaitCell<T> {
    /// Formats the cell's phase and whether any thread is parked on it; the value itself is
    /// not printed.
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        let bits = self.state.load(Ordering::Relaxed);
        let (phase, has_waiter) = if bits == STATE_INIT {
            // An initialized cell never reports waiters, even though all bits are set.
            ("initialized", false)
        } else {
            let phase = if (bits & BUSY_BIT) == 0 {
                "uninitialized"
            } else {
                "initializing"
            };
            (phase, (bits & !BUSY_BIT) != 0)
        };
        f.debug_struct("WaitCell")
            .field("state", &phase)
            .field("has_waiter", &has_waiter)
            .finish()
    }
}
impl<T> Default for WaitCell<T> {
fn default() -> Self {
Self::new()
}
}
impl<T> From<T> for WaitCell<T> {
fn from(value: T) -> Self {
Self::initialized(value)
}
}
// SAFETY: shared access hands out `&T` to several threads, which needs `T: Sync`. `T: Send` is
// required as well: `init`, `try_init` and `get_or_init` take `&self`, so one thread can move a
// value into the cell that a different thread later drops or takes out.
unsafe impl<T: Send + Sync> Sync for WaitCell<T> {}