use {
alloc::fmt,
core::{cell::UnsafeCell, convert::Infallible, mem::MaybeUninit},
};
/// A cell whose value is set at most once: either up front via `Once::initialized`,
/// or by the first successful `Once::call_once` / `Once::try_call_once`.
///
/// NOTE(review): state is tracked with plain `UnsafeCell<bool>` flags rather than atomics,
/// so nothing here synchronizes concurrent initializers. This presumably targets a
/// single-threaded / uniprocessor environment — confirm before sharing across threads.
pub struct Once<T = ()> {
// `true` once `data` holds an initialized `T` (set after a successful initializer,
// or from the start when built with `Once::initialized`).
initialized: UnsafeCell<bool>,
// Set to `true` while an initializer runs. If the initializer unwinds, it stays `true`,
// and every later `try_call_once` panics (the `Once` is poisoned).
panicked: UnsafeCell<bool>,
// Storage for the value; only initialized when `initialized` is `true`.
data: UnsafeCell<MaybeUninit<T>>,
}
impl<T: fmt::Debug> fmt::Debug for Once<T> {
    /// Renders as `Once(<value>)` when initialized and `Once(<uninit>)` otherwise.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut tuple = f.debug_tuple("Once");
        match self.get() {
            Some(value) => tuple.field(value),
            None => tuple.field(&format_args!("<uninit>")),
        };
        tuple.finish()
    }
}
impl<T> From<T> for Once<T> {
fn from(data: T) -> Self {
Self::initialized(data)
}
}
impl<T> Drop for Once<T> {
    /// Drops the stored value if one was ever written; an empty `Once` drops nothing.
    fn drop(&mut self) {
        if !self.is_completed() {
            return;
        }
        // SAFETY: `is_completed()` is true, so the slot holds an initialized `T`, and we have
        // exclusive access through `&mut self`; it is never touched again after this.
        unsafe {
            core::ptr::drop_in_place(self.data.get_mut().as_mut_ptr());
        }
    }
}
// SAFETY (author's claim): sharing `&Once<T>` hands out `&T` to every thread and lets whichever
// thread calls first run the initializer that produces the `T`, hence `T: Send + Sync`.
// NOTE(review): the `initialized`/`panicked` flags are plain `UnsafeCell<bool>`s written through
// `&self` without atomics, so two threads racing on `call_once` would be a data race. This impl
// is presumably only sound on a single-threaded / uniprocessor target — confirm.
unsafe impl<T> Sync for Once<T> where T: Send + Sync {}
// SAFETY: moving a `Once<T>` to another thread moves at most one owned `T`, so `T: Send` suffices.
unsafe impl<T> Send for Once<T> where T: Send {}
impl<T> Once<T> {
    /// An uninitialized `Once`, usable directly as a `static` initializer.
    ///
    /// This is the same value [`Once::new`] returns.
    // Clippy warns that every use of an interior-mutable `const` produces a fresh copy; that is
    // exactly the intent here, since each `Once` must get its own independent state.
    #[allow(clippy::declare_interior_mutable_const)]
    pub const INIT: Self = Self {
        initialized: UnsafeCell::new(false),
        panicked: UnsafeCell::new(false),
        data: UnsafeCell::new(MaybeUninit::uninit()),
    };

    /// Creates a new, uninitialized `Once`.
    pub const fn new() -> Self {
        Self::INIT
    }

    /// Returns a raw pointer to the (possibly uninitialized) storage slot.
    ///
    /// Reading through this pointer before the `Once` has completed reads uninitialized memory.
    pub fn as_mut_ptr(&self) -> *mut T {
        // `MaybeUninit<T>` is guaranteed to have the same size and alignment as `T`.
        self.data.get().cast::<T>()
    }

    /// Returns a shared reference to the stored value.
    ///
    /// # Safety
    /// The `Once` must be completed (the slot holds an initialized `T`).
    unsafe fn force_get(&self) -> &T {
        // SAFETY: the caller guarantees the slot is initialized.
        unsafe { &*(*self.data.get()).as_ptr() }
    }

    /// Returns a mutable reference to the stored value.
    ///
    /// # Safety
    /// The `Once` must be completed (the slot holds an initialized `T`).
    unsafe fn force_get_mut(&mut self) -> &mut T {
        // SAFETY: the caller guarantees the slot is initialized; `&mut self` is exclusive.
        unsafe { &mut *(*self.data.get()).as_mut_ptr() }
    }

    /// Moves the stored value out of `self`.
    ///
    /// # Safety
    /// The `Once` must be completed (the slot holds an initialized `T`).
    unsafe fn force_into_inner(self) -> T {
        // `Once` has a `Drop` impl that drops the stored value when `initialized` is true.
        // Without `ManuallyDrop`, that destructor would run when `self` goes out of scope here,
        // dropping the value we are about to hand to the caller: a double drop.
        let this = core::mem::ManuallyDrop::new(self);
        // SAFETY: the caller guarantees the slot is initialized, and `this` is never dropped,
        // so ownership of the value transfers to the caller exactly once. The remaining fields
        // are `UnsafeCell<bool>`s / `MaybeUninit`, so skipping their drop leaks nothing.
        unsafe { (*this.data.get()).as_ptr().read() }
    }

    /// Initializes the value with `f` if this `Once` has not completed yet, then returns a
    /// reference to the stored value.
    ///
    /// If the `Once` is already completed, `f` is not called.
    ///
    /// # Panics
    /// Panics with "Initialization panicked" if an earlier initializer panicked, or if this is
    /// called again (re-entrantly) while an initializer is still running. Also propagates any
    /// panic raised by `f`.
    pub fn call_once<F: FnOnce() -> T>(&self, f: F) -> &T {
        match self.try_call_once(|| Ok::<T, Infallible>(f())) {
            Ok(x) => x,
            Err(void) => match void {},
        }
    }

    /// Runs the fallible initializer `f` if this `Once` has not completed yet.
    ///
    /// On `Ok(value)` the value is stored and a reference to it is returned; later calls return
    /// that same reference without invoking their closure. On `Err(e)` nothing is stored, the
    /// `Once` remains uninitialized, and `e` is returned, so a later call may try again.
    ///
    /// # Errors
    /// Returns the error produced by `f`.
    ///
    /// # Panics
    /// Panics with "Initialization panicked" if an earlier initializer panicked (the `Once` is
    /// poisoned), or if called re-entrantly from inside `f`.
    pub fn try_call_once<F: FnOnce() -> Result<T, E>, E>(&self, f: F) -> Result<&T, E> {
        // SAFETY: the flags and the slot are read and written without synchronization. This is
        // sound only if no other call on this `Once` runs concurrently.
        // NOTE(review): presumably guaranteed by a single-threaded environment — confirm.
        // `force_get` is only reached once `initialized` is true.
        unsafe {
            if *self.panicked.get() {
                panic!("Initialization panicked");
            } else if self.is_completed() {
                Ok(self.force_get())
            } else {
                // Set before running `f`. If `f` unwinds, the flag stays set and poisons the
                // `Once`. If `f` re-enters, the check above catches it before any aliasing write.
                *self.panicked.get() = true;
                let value = f();
                *self.panicked.get() = false;
                (*self.data.get()).as_mut_ptr().write(value?);
                *self.initialized.get() = true;
                Ok(self.force_get())
            }
        }
    }

    /// Returns a reference to the value if the `Once` has completed, `None` otherwise.
    pub fn get(&self) -> Option<&T> {
        // SAFETY: `force_get` is only evaluated when `is_completed()` returned true.
        unsafe { self.is_completed().then(|| self.force_get()) }
    }

    /// Returns a mutable reference to the value if the `Once` has completed, `None` otherwise.
    pub fn get_mut(&mut self) -> Option<&mut T> {
        // SAFETY: `force_get_mut` is only evaluated when `is_completed()` returned true.
        unsafe { self.is_completed().then(|| self.force_get_mut()) }
    }

    /// Returns a mutable reference to the value without checking that it is initialized.
    ///
    /// # Safety
    /// The `Once` must be completed. This is checked with `debug_assert!` in debug builds only.
    pub unsafe fn get_mut_unchecked(&mut self) -> &mut T {
        debug_assert!(
            self.is_completed(),
            "Attempted to access an unintialized Once. If this was to run without debug checks, this would be undefined behavior. This is a serious bug and you must fix it.",
        );
        // SAFETY: upheld by the caller.
        unsafe { self.force_get_mut() }
    }

    /// Returns a shared reference to the value without checking that it is initialized.
    ///
    /// # Safety
    /// The `Once` must be completed. This is checked with `debug_assert!` in debug builds only.
    pub unsafe fn get_unchecked(&self) -> &T {
        debug_assert!(
            self.is_completed(),
            "Attempted to access an unintialized Once. If this was to run without debug checks, this would be undefined behavior. This is a serious bug and you must fix it.",
        );
        // SAFETY: upheld by the caller.
        unsafe { self.force_get() }
    }

    /// Creates a `Once` that is already completed and holds `data`.
    pub const fn initialized(data: T) -> Self {
        Self {
            initialized: UnsafeCell::new(true),
            panicked: UnsafeCell::new(false),
            data: UnsafeCell::new(MaybeUninit::new(data)),
        }
    }

    /// Consumes the `Once` and returns the stored value without checking that it is initialized.
    ///
    /// # Safety
    /// The `Once` must be completed. This is checked with `debug_assert!` in debug builds only.
    pub unsafe fn into_inner_unchecked(self) -> T {
        debug_assert!(
            self.is_completed(),
            "Attempted to access an unintialized Once. If this was to run without debug checks, this would be undefined behavior. This is a serious bug and you must fix it.",
        );
        // SAFETY: upheld by the caller; `force_into_inner` suppresses `Once`'s destructor.
        unsafe { self.force_into_inner() }
    }

    /// Returns `true` if the `Once` holds a value. The value may have been stored by a successful
    /// initializer, or supplied up front via [`Once::initialized`].
    pub fn is_completed(&self) -> bool {
        // SAFETY: plain read of the flag (unsynchronized; see `try_call_once`).
        unsafe { *self.initialized.get() }
    }

    /// Identical to [`Once::get`].
    pub fn poll(&self) -> Option<&T> {
        self.get()
    }

    /// Returns the stored value.
    ///
    /// Unlike a blocking `Once`, this does not wait for another initializer to finish.
    ///
    /// # Panics
    /// Panics if the `Once` has not completed.
    pub fn wait(&self) -> &T {
        self.get()
            .expect("Waited on uninitialized Once, who are you waiting for?")
    }
}
impl<T> Default for Once<T> {
fn default() -> Self {
Self::new()
}
}
// Unit tests for `Once`. Several tests spawn threads; see the NOTE(review) comments on the
// individual tests about how that interacts with this `Once`'s non-atomic flags.
#[cfg(test)]
mod tests {
use std::prelude::v1::*;
use std::sync::atomic::Ordering;
use std::sync::atomic::AtomicU32;
use std::sync::mpsc::channel;
use std::thread;
use super::*;
// The closure passed to `call_once` runs only on the first call.
#[test]
fn smoke_once() {
static O: Once = Once::new();
let mut a = 0;
O.call_once(|| a += 1);
assert_eq!(a, 1);
O.call_once(|| a += 1);
assert_eq!(a, 1);
}
// The first stored value is returned by every later `call_once`.
#[test]
fn smoke_once_value() {
static O: Once<usize> = Once::new();
let a = O.call_once(|| 1);
assert_eq!(*a, 1);
let b = O.call_once(|| 2);
assert_eq!(*b, 1);
}
// Many threads race to initialize the same `Once`; the initializer must run exactly once.
// NOTE(review): `Once`'s `initialized`/`panicked` flags are plain `UnsafeCell<bool>`s with no
// atomics, so concurrent `call_once` calls are a data race. A thread may observe
// `panicked == true` and panic, or two initializers may run. This test looks racy — confirm.
#[test]
fn stampede_once() {
static O: Once = Once::new();
static mut RUN: bool = false;
let (tx, rx) = channel();
let mut ts = Vec::new();
for _ in 0..10 {
let tx = tx.clone();
ts.push(thread::spawn(move || {
for _ in 0..4 {
thread::yield_now()
}
unsafe {
O.call_once(|| {
assert!(!RUN);
RUN = true;
});
assert!(RUN);
}
tx.send(()).unwrap();
}));
}
unsafe {
O.call_once(|| {
assert!(!RUN);
RUN = true;
});
assert!(RUN);
}
for _ in 0..10 {
rx.recv().unwrap();
}
for t in ts {
t.join().unwrap();
}
}
// `get` is `None` before initialization and `Some` afterwards.
#[test]
fn get() {
static INIT: Once<usize> = Once::new();
assert!(INIT.get().is_none());
INIT.call_once(|| 2);
assert_eq!(INIT.get().copied(), Some(2));
}
// `get` must not block while another thread is running the initializer.
// NOTE(review): the main thread reads the flags while the spawned thread writes `panicked`,
// with no synchronization (data race in principle). The 3-second sleep also slows the suite.
#[test]
fn get_no_wait() {
static INIT: Once<usize> = Once::new();
assert!(INIT.get().is_none());
let t = thread::spawn(move || {
INIT.call_once(|| {
thread::sleep(std::time::Duration::from_secs(3));
42
});
});
assert!(INIT.get().is_none());
t.join().unwrap();
}
// `poll` behaves like `get`.
#[test]
fn poll() {
static INIT: Once<usize> = Once::new();
assert!(INIT.poll().is_none());
INIT.call_once(|| 3);
assert_eq!(INIT.poll().copied(), Some(3));
}
// NOTE(review): `Once::wait` in this file does not block; it panics if the value is not yet
// set. If the spawned thread runs before the `call_once` below, it panics and
// `t.join().unwrap()` fails, so this test looks racy — confirm the intended semantics.
#[test]
fn wait() {
static INIT: Once<usize> = Once::new();
let t = std::thread::spawn(|| {
assert_eq!(*INIT.wait(), 3);
assert!(INIT.is_completed());
});
for _ in 0..4 {
thread::yield_now()
}
assert!(INIT.poll().is_none());
INIT.call_once(|| 3);
t.join().unwrap();
}
// A panicking initializer poisons the `Once`: the next `call_once` panics too.
#[test]
fn panic() {
use std::panic;
static INIT: Once = Once::new();
let t = panic::catch_unwind(|| {
INIT.call_once(|| panic!());
});
assert!(t.is_err());
let t = panic::catch_unwind(|| {
INIT.call_once(|| {});
});
assert!(t.is_err());
}
// `Once::INIT` behaves the same as `Once::new()`.
#[test]
fn init_constant() {
static O: Once = Once::INIT;
let mut a = 0;
O.call_once(|| a += 1);
assert_eq!(a, 1);
O.call_once(|| a += 1);
assert_eq!(a, 1);
}
// Global flag set by `DropTest::drop`, used to observe whether the stored value was dropped.
static mut CALLED: bool = false;
struct DropTest {}
impl Drop for DropTest {
fn drop(&mut self) {
unsafe {
CALLED = true;
}
}
}
// An `Err` from `try_call_once` leaves the `Once` uninitialized, so the next `call_once`
// still runs its closure (two calls total).
#[test]
fn try_call_once_err() {
let once = Once::<_>::new();
let called = AtomicU32::new(0);
once.try_call_once(|| {
called.fetch_add(1, Ordering::AcqRel);
thread::sleep(std::time::Duration::from_millis(50));
Err(())
})
.ok();
once.call_once(|| {
called.fetch_add(1, Ordering::AcqRel);
});
assert_eq!(called.load(Ordering::Acquire), 2);
}
// Dropping an initialized `Once` drops its value; dropping an uninitialized one does not.
#[test]
fn drop_occurs_and_skip_uninit_drop() {
unsafe {
CALLED = false;
}
{
let once = Once::<_>::new();
once.call_once(|| DropTest {});
}
assert!(unsafe { CALLED });
unsafe {
CALLED = false;
}
let once = Once::<DropTest>::new();
drop(once);
assert!(unsafe { !CALLED });
}
// Repeated `call_once` on a fresh `Once` runs the closure exactly once per instance.
#[test]
fn call_once_test() {
for _ in 0..20 {
use std::sync::Arc;
use std::sync::atomic::AtomicUsize;
let share = Arc::new(AtomicUsize::new(0));
let once = Arc::new(Once::<_>::new());
for _ in 0..8 {
once.call_once(|| {
share.fetch_add(1, Ordering::SeqCst);
});
}
assert_eq!(1, share.load(Ordering::SeqCst));
}
}
}