use crate::shared::{fence_acquire, invalid_mut, SpinWait, StrictProvenance, Waiter};
use std::{
fmt,
mem::drop,
ptr::NonNull,
sync::atomic::{AtomicPtr, Ordering},
};
// Tags stored in the low (non-pointer) bits of `Once::state`.

/// No closure has been started yet; this is the value `Once::new` installs.
const UNINIT: usize = 0b00;
/// A thread has claimed the `Once` and is running its closure; the pointer bits
/// hold the head of the queue of parked waiters, if any.
const CALLING: usize = 0b01;
/// A previous closure panicked before returning.
const POISONED: usize = 0b10;
/// A closure returned normally; later calls return without running anything.
const COMPLETED: usize = 0b11;
/// One-time initialization primitive.
///
/// All state lives in a single tagged pointer word (`state`):
/// - its low bits hold one of the `UNINIT` / `CALLING` / `POISONED` / `COMPLETED` tags;
/// - while a closure is running, the remaining bits may point at the head of an
///   intrusive list of parked `Waiter`s.
///
/// `Default` produces a null pointer, whose address equals `UNINIT`, so it is the
/// same starting state as `Once::new()`.
#[derive(Default)]
pub struct Once {
    state: AtomicPtr<Waiter>,
}

// SAFETY: the only field is an `AtomicPtr`, so every shared access to the state
// word is atomic. The queued `Waiter` pointers stored in it are dereferenced by
// other threads only when the thread that claimed `CALLING` wakes them.
unsafe impl Sync for Once {}
// SAFETY: `Once` holds no thread-bound data; moving it only moves the atomic word.
unsafe impl Send for Once {}
impl fmt::Debug for Once {
    /// Renders as `Once { state: <OnceState> }`, using the same relaxed
    /// snapshot that `Once::state` returns.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let current = self.state();
        let mut builder = f.debug_struct("Once");
        builder.field("state", &current);
        builder.finish()
    }
}
impl Once {
    /// Creates a new `Once` that has not run any closure yet.
    pub const fn new() -> Self {
        Self {
            state: AtomicPtr::new(invalid_mut(UNINIT)),
        }
    }

    /// Returns a snapshot of this `Once`'s state.
    ///
    /// The state word is read with `Ordering::Relaxed`, so the result is advisory:
    /// it may already be stale when inspected. Observing `OnceState::Done` here
    /// does not by itself synchronize with the closure that ran.
    ///
    /// # Panics
    ///
    /// Panics (via `unreachable!`) if the tag bits hold an unknown value.
    pub fn state(&self) -> OnceState {
        let state = self.state.load(Ordering::Relaxed);
        match state.address() & !Waiter::MASK {
            UNINIT => OnceState::New,
            CALLING => OnceState::InProgress,
            POISONED => OnceState::Poisoned,
            COMPLETED => OnceState::Done,
            _ => unreachable!("invalid state"),
        }
    }

    /// Runs `f` unless a closure has already completed on this `Once`.
    ///
    /// If another thread is currently running a closure, the caller parks until
    /// that closure finishes and then re-examines the state. When this returns
    /// normally, some closure has completed. The completion is observed through
    /// an `Acquire` load, or through `fence_acquire` on the slow path.
    ///
    /// # Panics
    ///
    /// Panics if a previous closure panicked (the `Once` is poisoned). If `f`
    /// itself panics, the panic propagates and the `Once` becomes poisoned.
    #[inline]
    pub fn call_once<F>(&self, f: F)
    where
        F: FnOnce(),
    {
        // Fast path: an exact `COMPLETED` address means there is nothing left to do.
        let state = self.state.load(Ordering::Acquire);
        if state.address() == COMPLETED {
            return;
        }
        self.call_once_slow(false, |_: OnceState| f());
    }

    /// Like [`Once::call_once`], but also runs `f` when the `Once` is poisoned.
    ///
    /// `f` receives `OnceState::New` on first use, or `OnceState::Poisoned` if a
    /// previous closure panicked. If `f` returns normally the `Once` becomes done.
    /// If it panics, the `Once` is (again) poisoned and the panic propagates.
    #[inline]
    pub fn call_once_force<F>(&self, f: F)
    where
        F: FnOnce(OnceState),
    {
        let state = self.state.load(Ordering::Acquire);
        if state.address() == COMPLETED {
            return;
        }
        self.call_once_slow(true, f);
    }

    /// Slow path shared by both public entry points.
    ///
    /// Loops on the state word and takes one of these actions:
    /// - returns once the state is `COMPLETED`;
    /// - panics on `POISONED` unless `ignore_poison` is set;
    /// - enqueues this thread's `Waiter` and parks while another thread holds `CALLING`;
    /// - otherwise tries to claim `CALLING` and run `f` through `do_call`.
    #[cold]
    fn call_once_slow<F>(&self, ignore_poison: bool, f: F)
    where
        F: FnOnce(OnceState),
    {
        Waiter::with(|waiter| {
            let mut spin = SpinWait::default();
            let mut state = self.state.load(Ordering::Relaxed);
            loop {
                // The Relaxed load is upgraded by an acquire fence before returning,
                // so the completed closure's writes are visible to the caller.
                if state.address() == COMPLETED {
                    fence_acquire(&self.state);
                    return;
                }
                if state.address() == POISONED && !ignore_poison {
                    fence_acquire(&self.state);
                    panic!("Once instance was previously poisoned");
                }
                if state.address() & !Waiter::MASK == CALLING {
                    // Pointer bits = head of the intrusive waiter list (may be null).
                    let head = NonNull::new(state.map_address(|addr| addr & Waiter::MASK));
                    // Nobody has queued yet: spin/yield briefly before committing to park.
                    if head.is_none() && spin.try_yield_now() {
                        state = self.state.load(Ordering::Relaxed);
                        continue;
                    }
                    // Push this thread's waiter onto the front of the list. The new
                    // state keeps the CALLING tag in its low bits.
                    waiter.next.set(head);
                    let waiter_ptr = NonNull::from(&*waiter).as_ptr();
                    let new_state = waiter_ptr.map_address(|addr| addr | CALLING);
                    // Release publishes `waiter.next` to the thread that later swaps
                    // the list out in `StateGuard::drop`.
                    if let Err(e) = self.state.compare_exchange_weak(
                        state,
                        new_state,
                        Ordering::Release,
                        Ordering::Relaxed,
                    ) {
                        state = e;
                        continue;
                    }
                    // Block until woken. The assert expects `park(None)` to report
                    // `true`, presumably meaning "woken by unpark" (confirm in `Parker`).
                    assert!(waiter.parker.park(None));
                    state = self.state.load(Ordering::Relaxed);
                    continue;
                }
                // State is UNINIT, or POISONED with `ignore_poison`: try to claim it.
                // `with_address` keeps the provenance of the loaded pointer. Acquire
                // pairs with the Release swap of a previous (panicked) run.
                match self.state.compare_exchange_weak(
                    state,
                    state.with_address(CALLING),
                    Ordering::Acquire,
                    Ordering::Relaxed,
                ) {
                    Ok(_) => return self.do_call(state, f),
                    Err(e) => state = e,
                }
            }
        })
    }

    /// Runs `f` while this thread owns the `CALLING` state, then publishes the
    /// outcome and wakes every queued waiter.
    ///
    /// `old_state` is the state word that `CALLING` replaced. It determines the
    /// `OnceState` passed to `f` and supplies the provenance of the value written back.
    #[cold]
    fn do_call<F>(&self, old_state: *mut Waiter, f: F)
    where
        F: FnOnce(OnceState),
    {
        // Restores the state word and wakes all waiters when dropped. Because this
        // runs in `Drop`, it also happens during unwinding if `f` panics.
        struct StateGuard<'a> {
            once: &'a Once,
            // Value stored on drop: POISONED unless `f` returned normally.
            reset_to: *mut Waiter,
        }

        impl<'a> Drop for StateGuard<'a> {
            fn drop(&mut self) {
                // AcqRel: Acquire pairs with the waiters' Release CAS so their `next`
                // links are visible. Release publishes `f`'s effects to threads that
                // later observe the new state.
                let state = self.once.state.swap(self.reset_to, Ordering::AcqRel);
                // Only the claiming thread runs `do_call`, and enqueuing waiters keep
                // the CALLING tag, so the tag must still be CALLING here.
                assert_eq!(state.address() & !Waiter::MASK, CALLING);
                let mut waiters = NonNull::new(state.map_address(|addr| addr & Waiter::MASK));
                while let Some(waiter) = waiters {
                    // SAFETY: each queued pointer was pushed by a thread that stays
                    // inside `Waiter::with` parked on this waiter until it is unparked,
                    // so the waiter is still alive here. `next` is read *before*
                    // `unpark`, because once woken the owner may reuse or free the waiter.
                    // (Assumes `park(None)` returns only after `unpark`; confirm in `Parker`.)
                    unsafe {
                        waiters = waiter.as_ref().next.get();
                        waiter.as_ref().parker.unpark();
                    }
                }
            }
        }

        let mut state_guard = StateGuard {
            once: self,
            reset_to: old_state.with_address(POISONED),
        };
        f(match old_state.address() {
            UNINIT => OnceState::New,
            POISONED => OnceState::Poisoned,
            _ => unreachable!("invalid once state on invokation"),
        });
        // `f` returned normally: publish COMPLETED instead of POISONED.
        state_guard.reset_to = old_state.with_address(COMPLETED);
        drop(state_guard);
    }
}
/// Observable state of a `Once`, as returned by `Once::state` and passed to
/// `call_once_force` closures.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum OnceState {
    /// No closure has run yet.
    New,
    /// A previous closure panicked before returning.
    Poisoned,
    /// A closure is currently running.
    InProgress,
    /// A closure returned normally.
    Done,
}

impl OnceState {
    /// Returns `true` only for `OnceState::Poisoned`.
    #[inline]
    pub fn poisoned(self) -> bool {
        matches!(self, OnceState::Poisoned)
    }

    /// Returns `true` only for `OnceState::Done`.
    #[inline]
    pub fn done(self) -> bool {
        matches!(self, OnceState::Done)
    }
}
// Unit tests for `Once` and `OnceState`.
#[cfg(test)]
mod tests {
use crate::Once;
use std::{panic, sync::mpsc::channel, thread};
// A completed `Once` must not run a second closure.
#[test]
fn smoke_once() {
static O: Once = Once::new();
let mut a = 0;
O.call_once(|| a += 1);
assert_eq!(a, 1);
O.call_once(|| a += 1);
assert_eq!(a, 1);
}
// Eleven callers (ten spawned threads plus the test thread) race on one `Once`.
// The closure must run exactly once, and every caller must observe its write
// once `call_once` returns. `RUN` is a plain, non-atomic `static mut`, so the
// test relies on `Once` itself for synchronization.
#[test]
fn stampede_once() {
static O: Once = Once::new();
static mut RUN: bool = false;
let (tx, rx) = channel();
for _ in 0..10 {
let tx = tx.clone();
thread::spawn(move || {
// A few yields before racing; presumably meant to vary thread timing
// and increase contention.
for _ in 0..4 {
thread::yield_now()
}
unsafe {
O.call_once(|| {
assert!(!RUN);
RUN = true;
});
assert!(RUN);
}
tx.send(()).unwrap();
});
}
unsafe {
O.call_once(|| {
assert!(!RUN);
RUN = true;
});
assert!(RUN);
}
// Wait until every spawned thread has finished its check.
for _ in 0..10 {
rx.recv().unwrap();
}
}
// A panicking closure poisons the `Once`. Plain `call_once` then panics, while
// `call_once_force` runs with `OnceState::Poisoned`. After it succeeds, the
// `Once` is done and `call_once` no longer panics.
#[test]
fn poison_bad() {
static O: Once = Once::new();
let t = panic::catch_unwind(|| {
O.call_once(|| panic!());
});
assert!(t.is_err());
let t = panic::catch_unwind(|| {
O.call_once(|| {});
});
assert!(t.is_err());
let mut called = false;
O.call_once_force(|p| {
called = true;
assert!(p.poisoned())
});
assert!(called);
O.call_once(|| {});
}
// While t1's `call_once_force` closure is running (blocked on `rx2`), a
// concurrent `call_once` in t2 must not run its own closure. It either waits
// for t1 or arrives after completion.
// NOTE(review): nothing forces t2 to enter `call_once` before `tx2.send`, so the
// park/wake path is exercised only if t2 gets there first.
#[test]
fn wait_for_force_to_finish() {
static O: Once = Once::new();
let t = panic::catch_unwind(|| {
O.call_once(|| panic!());
});
assert!(t.is_err());
let (tx1, rx1) = channel();
let (tx2, rx2) = channel();
let t1 = thread::spawn(move || {
O.call_once_force(|p| {
assert!(p.poisoned());
tx1.send(()).unwrap();
rx2.recv().unwrap();
});
});
// Block until t1 is inside its closure, i.e. the `Once` is in progress.
rx1.recv().unwrap();
let t2 = thread::spawn(|| {
let mut called = false;
O.call_once(|| {
called = true;
});
assert!(!called);
});
// Let t1's closure finish.
tx2.send(()).unwrap();
assert!(t1.join().is_ok());
assert!(t2.join().is_ok());
}
// The `Debug` output of a fresh `Once` reports the `New` state.
#[test]
fn test_once_debug() {
static O: Once = Once::new();
assert_eq!(format!("{:?}", O), "Once { state: New }");
}
}