#![cfg_attr(not(test), no_std)]
#![deny(missing_docs)]
use core::cell::UnsafeCell;
use core::future::{self, Future};
use core::hint::unreachable_unchecked;
use core::mem::{needs_drop, MaybeUninit};
use core::task::Poll;
use portable_atomic::{self as atomic, AtomicUsize};
/// Raw values held in the `state` field of `InitOnce`.
///
/// The state only moves forward: `EMPTY` -> `INITIALIZING` -> `INITIALIZED`
/// (the `&mut self` initializers may also go straight to `INITIALIZED`).
mod init_once_state {
    /// No value stored and nobody has claimed the right to store one.
    pub(crate) const EMPTY: usize = 0;
    /// A caller has claimed the cell but no value has been stored yet.
    pub(crate) const INITIALIZING: usize = 1;
    /// A value has been written into the cell.
    pub(crate) const INITIALIZED: usize = 2;
}
/// The observed state of an [`InitOnce`], as returned by [`InitOnce::state`].
#[derive(Debug)]
pub enum InitState<'a, T> {
    /// The cell was claimed for initialization by an earlier call to
    /// [`InitOnce::state`] (which returned [`InitState::Polling`]), and no value has
    /// been stored yet.
    Initializing,
    /// The cell holds a value; this is a shared reference to it.
    Initialized(&'a T),
    /// The cell was empty and this call claimed it. The caller is now responsible
    /// for driving the contained [`PollInit`] to completion. If it does not, the
    /// cell stays in the initializing state for all shared access.
    Polling(PollInit<'a, T>),
}
/// A cell that is initialized at most once and can be shared between threads.
///
/// Callers never wait for an in-progress initialization. A caller that finds the
/// cell being initialized observes [`InitState::Initializing`] (or `None` from
/// [`InitOnce::try_init`]) instead of blocking.
///
/// Through a shared reference, exactly one caller wins the right to initialize the
/// cell ([`InitState::Polling`]). The cell is abandoned if that caller drops its
/// [`PollInit`] before completion or panics inside the initializer. An abandoned cell
/// stays in the initializing state for all shared access. Only the `&mut self`
/// methods ([`InitOnce::init`], [`InitOnce::init_async`]) can still fill it.
#[derive(Debug)]
pub struct InitOnce<T> {
    /// Storage for the value; it holds an initialized `T` only once `state` is
    /// `init_once_state::INITIALIZED`.
    cell: UnsafeCell<MaybeUninit<T>>,
    /// One of the `init_once_state` constants.
    state: AtomicUsize,
}
/// Exclusive handle for initializing an [`InitOnce`], obtained from
/// [`InitState::Polling`].
///
/// At most one `PollInit` is handed out per cell. Drive it with
/// [`PollInit::poll_init`] or [`PollInit::init_async`] until it yields a value.
/// Dropping it earlier leaves the cell in the initializing state for shared access.
#[derive(Debug)]
#[must_use = "dropping a PollInit before it completes leaves the InitOnce stuck in the \
              initializing state"]
pub struct PollInit<'a, T> {
    /// Set once the value has been written, so later polls return it directly.
    polled_to_completion: bool,
    /// The cell this handle initializes.
    init_once: &'a InitOnce<T>,
}
// SAFETY: sharing `&InitOnce<T>` across threads allows two things.
// - Any thread can read the stored `T` through it, so `T: Sync` is required.
// - Any thread can store a `T` that is later dropped by a *different* thread when
//   the `InitOnce` itself is dropped, so `T: Send` is required as well.
// For example, a scoped thread initializes the cell and the owner thread then drops
// it. These are the same bounds `std::sync::OnceLock` uses.
unsafe impl<T: Sync + Send> Sync for InitOnce<T> {}
impl<T> Drop for InitOnce<T> {
    /// Runs the stored value's destructor, if a value was ever written.
    fn drop(&mut self) {
        // Types without drop glue need no work at all.
        if !needs_drop::<T>() {
            return;
        }
        // Only an `INITIALIZED` cell holds a value; any other state means the slot
        // was never written.
        if *self.state.get_mut() != init_once_state::INITIALIZED {
            return;
        }
        // SAFETY: `INITIALIZED` is only stored after the value has been written, and
        // `&mut self` guarantees no outstanding borrows of the value remain.
        unsafe { self.cell.get_mut().assume_init_drop() };
    }
}
impl<T> Default for InitOnce<T> {
#[inline]
fn default() -> Self {
Self::new()
}
}
impl<T> InitOnce<T> {
    /// Creates a new, empty cell.
    pub const fn new() -> Self {
        Self {
            cell: UnsafeCell::new(MaybeUninit::uninit()),
            state: AtomicUsize::new(init_once_state::EMPTY),
        }
    }

    /// Wraps `self` in a fresh [`PollInit`] handle.
    ///
    /// Must only be called by the caller that moved the state from `EMPTY` to
    /// `INITIALIZING`, so that at most one handle exists per cell.
    #[must_use]
    fn poll_init_begin(&self) -> PollInit<'_, T> {
        PollInit {
            init_once: self,
            polled_to_completion: false,
        }
    }

    /// Returns the current state of the cell. If the cell is still empty, this call
    /// claims it for initialization.
    ///
    /// * [`InitState::Initialized`]: a value is stored and a reference to it is returned.
    /// * [`InitState::Initializing`]: an earlier call already claimed the cell and no
    ///   value has been stored yet. This never blocks.
    /// * [`InitState::Polling`]: the cell was empty and this call claimed it. The
    ///   caller must drive the returned [`PollInit`] to completion. Otherwise the cell
    ///   stays in the initializing state for all shared access.
    #[must_use = "The state of an InitOnce (i.e. InitState) must always be consumed. If you do \
                  not poll the value initializer to completion, the value will never be initialized."]
    #[inline]
    pub fn state(&self) -> InitState<'_, T> {
        // Fast path: the state never moves backwards, so a plain load answers for any
        // already-claimed cell. This avoids a read-modify-write (and exclusive
        // cache-line ownership) on every read of an initialized value.
        let observed = match self.state.load(atomic::Ordering::SeqCst) {
            init_once_state::EMPTY => match self.state.compare_exchange(
                init_once_state::EMPTY,
                init_once_state::INITIALIZING,
                atomic::Ordering::SeqCst,
                atomic::Ordering::SeqCst,
            ) {
                Ok(_) => return unlikely_call(|| InitState::Polling(self.poll_init_begin())),
                // Lost the race: someone else claimed (and possibly finished) it.
                Err(current) => current,
            },
            current => current,
        };
        match observed {
            init_once_state::INITIALIZING => InitState::Initializing,
            init_once_state::INITIALIZED => {
                // SAFETY: `INITIALIZED` is only stored after the value was written. The
                // SeqCst load / compare_exchange above synchronizes with that store, so
                // the write is visible. The value is never mutated or dropped while a
                // shared borrow of `self` exists.
                InitState::Initialized(unsafe { (*self.cell.get()).assume_init_ref() })
            }
            // SAFETY: the state only ever holds one of the three `init_once_state`
            // constants. `EMPTY` is excluded here: a plain `EMPTY` load goes to the
            // compare_exchange, and a failed compare_exchange against `EMPTY` never
            // reports `EMPTY`.
            _ => unsafe { unreachable_unchecked() },
        }
    }

    /// Returns the stored value. If the cell is empty, initializes it with `init` first.
    ///
    /// Never blocks. If another caller is currently initializing the cell, this
    /// returns `None` without calling `init`.
    ///
    /// If `init` panics, the cell is left in the initializing state for shared access.
    #[inline]
    pub fn try_init<F>(&self, mut init: F) -> Option<&T>
    where
        F: FnMut() -> T,
    {
        match self.state() {
            InitState::Initialized(value) => Some(value),
            InitState::Initializing => None,
            InitState::Polling(mut poller) => match poller.poll_init(|| Poll::Ready(init())) {
                Poll::Ready(value) => Some(value),
                // SAFETY: the initializer passed above always returns `Poll::Ready`, and
                // `poll_init` only returns `Pending` when its initializer does.
                Poll::Pending => unsafe { unreachable_unchecked() },
            },
        }
    }

    /// Async counterpart of [`InitOnce::try_init`]. Returns the stored value, first
    /// initializing it by awaiting `init` if the cell is empty.
    ///
    /// Returns `None` without awaiting `init` if another caller is currently
    /// initializing. Suppose this call claims the cell and the returned future is
    /// dropped before `init` completes. The cell then stays in the initializing state
    /// for shared access.
    pub async fn try_init_async<F>(&self, init: F) -> Option<&T>
    where
        F: Future<Output = T>,
    {
        match self.state() {
            InitState::Initialized(value) => Some(value),
            InitState::Initializing => None,
            InitState::Polling(mut poller) => Some(poller.init_async(init).await),
        }
    }

    /// Returns a mutable reference to the value. If no value has been stored yet,
    /// initializes it with `init` first.
    ///
    /// Because this takes `&mut self`, no other access can be in progress. It
    /// therefore also fills a cell whose shared initialization was abandoned. `init`
    /// is called at most once, and not at all if a value is already stored.
    pub fn init<F>(&mut self, init: F) -> &mut T
    where
        F: FnOnce() -> T,
    {
        let maybe_uninit = self.cell.get_mut();
        if *self.state.get_mut() != init_once_state::INITIALIZED {
            unlikely_call(|| {
                maybe_uninit.write(init());
                *self.state.get_mut() = init_once_state::INITIALIZED;
            });
        }
        // SAFETY: either the state already was `INITIALIZED` (value written earlier)
        // or the branch above just wrote the value.
        unsafe { maybe_uninit.assume_init_mut() }
    }

    /// Async counterpart of [`InitOnce::init`]. Returns a mutable reference to the
    /// value, first awaiting `init` to store one if the cell has no value yet.
    ///
    /// If a value is already stored, `init` is dropped without being awaited. If the
    /// returned future is dropped before `init` completes, nothing is stored and the
    /// state is unchanged.
    pub async fn init_async<F>(&mut self, init: F) -> &mut T
    where
        F: Future<Output = T>,
    {
        let maybe_uninit = self.cell.get_mut();
        if *self.state.get_mut() != init_once_state::INITIALIZED {
            unlikely_call(|| async {
                maybe_uninit.write(init.await);
                *self.state.get_mut() = init_once_state::INITIALIZED;
            })
            .await;
        }
        // SAFETY: either the state already was `INITIALIZED` (value written earlier)
        // or the branch above just wrote the value.
        unsafe { maybe_uninit.assume_init_mut() }
    }
}
impl<'init_once, T> PollInit<'init_once, T> {
    /// Initializes the cell by awaiting `init` and returns a reference to the stored
    /// value.
    ///
    /// If this handle already completed initialization, `init` is never polled and
    /// the existing value is returned.
    pub async fn init_async<F>(&mut self, init: F) -> &'init_once T
    where
        F: Future<Output = T>,
    {
        let mut pinned_init = core::pin::pin!(init);
        future::poll_fn(|cx| self.poll_init(|| pinned_init.as_mut().poll(cx))).await
    }

    /// Polls the initializer once and stores its output when it becomes ready.
    ///
    /// * If `init` returns `Poll::Pending`, this returns `Poll::Pending`. The cell
    ///   stays in the initializing state; call again to retry.
    /// * If `init` returns `Poll::Ready(value)`, this stores `value`, marks the cell
    ///   initialized, and returns a reference to the value.
    /// * Once a value has been stored, later calls return it without calling `init`.
    ///
    /// `init` is called at most once per call, so any `FnOnce` is accepted.
    pub fn poll_init<F>(&mut self, init: F) -> Poll<&'init_once T>
    where
        F: FnOnce() -> Poll<T>,
    {
        if self.polled_to_completion {
            return unlikely_call(|| {
                // SAFETY: `polled_to_completion` is only set after the value was written.
                Poll::Ready(unsafe { (*self.init_once.cell.get()).assume_init_ref() })
            });
        }
        let value = core::task::ready!(init());
        // SAFETY: only the caller whose compare_exchange claimed the cell receives a
        // `PollInit`, and this method takes `&mut self`, so this is the sole writer.
        // The slot is still uninitialized because `polled_to_completion` is false.
        // Readers only touch the cell after observing `INITIALIZED`, which is stored
        // below, after this write.
        unsafe {
            let slot = (*self.init_once.cell.get()).as_mut_ptr();
            core::ptr::write(slot, value);
        }
        self.init_once
            .state
            .store(init_once_state::INITIALIZED, atomic::Ordering::SeqCst);
        self.polled_to_completion = true;
        // SAFETY: the value was written just above.
        Poll::Ready(unsafe { (*self.init_once.cell.get()).assume_init_ref() })
    }
}
/// Invokes `body` and returns its result.
///
/// The `#[cold]` and `#[inline(never)]` attributes mark every call site as an
/// unlikely path for the optimizer, so wrapping a branch in this helper acts as a
/// branch-weight hint.
#[cold]
#[inline(never)]
fn unlikely_call<R, Body>(body: Body) -> R
where
    Body: FnOnce() -> R,
{
    body()
}
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex};
    use std::thread;

    use super::*;

    /// Increments a shared counter when dropped, so tests can count drops.
    struct TrackDrop {
        count: Arc<Mutex<usize>>,
    }

    impl Drop for TrackDrop {
        fn drop(&mut self) {
            *self.count.lock().unwrap() += 1;
        }
    }

    /// A caller that loses the initialization race gets `None` immediately instead
    /// of waiting for the winner's initializer.
    #[test]
    fn try_init_wont_block() {
        struct Shared {
            init_once: InitOnce<()>,
            thread_barrier: std::sync::Barrier,
            init_barrier: std::sync::Barrier,
        }
        let shared = Arc::new(Shared {
            init_once: InitOnce::new(),
            thread_barrier: std::sync::Barrier::new(2),
            init_barrier: std::sync::Barrier::new(2),
        });
        let shared2 = Arc::clone(&shared);
        let handle = std::thread::spawn(move || {
            shared2.thread_barrier.wait();
            assert!(shared2
                .init_once
                .try_init(|| {
                    // Keep the cell in the initializing state until the main thread
                    // has checked it.
                    shared2.init_barrier.wait();
                })
                .is_some());
        });
        shared.thread_barrier.wait();
        // NOTE(review): timing-based; assumes 50ms lets the spawned thread claim the cell.
        std::thread::sleep(std::time::Duration::from_millis(50));
        assert!(shared.init_once.try_init(|| panic!()).is_none());
        shared.init_barrier.wait();
        handle.join().unwrap();
        assert!(shared.init_once.try_init(|| panic!()).is_some());
    }

    /// Async variant of `try_init_wont_block`, checking both sync and async accessors.
    #[tokio::test]
    async fn try_init_async_wont_block() {
        struct Shared {
            init_once: InitOnce<()>,
            thread_barrier: tokio::sync::Barrier,
            init_barrier: tokio::sync::Barrier,
        }
        let shared = Arc::new(Shared {
            init_once: InitOnce::new(),
            thread_barrier: tokio::sync::Barrier::new(2),
            init_barrier: tokio::sync::Barrier::new(2),
        });
        let shared2 = Arc::clone(&shared);
        let handle = tokio::spawn(async move {
            shared2.thread_barrier.wait().await;
            assert!(shared2
                .init_once
                .try_init_async(async {
                    shared2.init_barrier.wait().await;
                })
                .await
                .is_some());
        });
        shared.thread_barrier.wait().await;
        tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
        assert!(shared.init_once.try_init(|| panic!()).is_none());
        assert!(shared
            .init_once
            .try_init_async(async { panic!() })
            .await
            .is_none());
        shared.init_barrier.wait().await;
        handle.await.unwrap();
        assert!(shared.init_once.try_init(|| panic!()).is_some());
        assert!(shared
            .init_once
            .try_init_async(async { panic!() })
            .await
            .is_some());
    }

    /// `init(&mut self)` runs its initializer only on the first call.
    #[test]
    fn init_mut_only_once() {
        let mut initialized = 0;
        let mut init_once = InitOnce::new();
        for _ in 0..10 {
            init_once.init(|| {
                initialized += 1;
            });
        }
        assert_eq!(initialized, 1);
    }

    /// `init_async(&mut self)` awaits its initializer only on the first call.
    #[tokio::test]
    async fn init_mut_async_only_once() {
        let mut initialized = 0;
        let mut init_once = InitOnce::new();
        for _ in 0..10 {
            init_once
                .init_async(async {
                    initialized += 1;
                })
                .await;
        }
        assert_eq!(initialized, 1);
    }

    /// The stored value is dropped exactly once, and only when the cell is dropped.
    #[tokio::test]
    async fn dropped_once_if_init() {
        let mut init_once = Arc::new(InitOnce::new());
        let count = Arc::new(Mutex::new(0));
        assert_eq!(
            *Arc::get_mut(&mut init_once).unwrap().state.get_mut(),
            init_once_state::EMPTY
        );
        let tasks: Vec<_> = (0..10)
            .map(|_| {
                let init_once = Arc::clone(&init_once);
                let count = Arc::clone(&count);
                tokio::spawn(async move {
                    if let InitState::Polling(mut poller) = init_once.state() {
                        let TrackDrop {
                            count: current_count,
                        } = poller.init_async(future::ready(TrackDrop { count })).await;
                        assert_eq!(*current_count.lock().unwrap(), 0);
                    }
                })
            })
            .collect();
        for handle in tasks {
            handle.await.unwrap();
        }
        assert_eq!(*count.lock().unwrap(), 0);
        assert_eq!(
            *Arc::get_mut(&mut init_once).unwrap().state.get_mut(),
            init_once_state::INITIALIZED
        );
        drop(init_once);
        assert_eq!(*count.lock().unwrap(), 1);
    }

    /// Exactly one caller wins the claim. If it never polls, the cell stays
    /// initializing for shared access.
    #[test]
    fn never_poll_init() {
        let mut init_once = Arc::new(InitOnce::<()>::new());
        let count = Arc::new(Mutex::new(0));
        assert_eq!(
            *Arc::get_mut(&mut init_once).unwrap().state.get_mut(),
            init_once_state::EMPTY
        );
        assert_eq!(*count.lock().unwrap(), 0);
        let threads: Vec<_> = (0..10)
            .map(|_| {
                let init_once = Arc::clone(&init_once);
                let count = Arc::clone(&count);
                thread::spawn(move || {
                    // Only the thread that receives `Polling` bumps the counter.
                    if matches!(init_once.state(), InitState::Polling(_)) {
                        drop(TrackDrop { count });
                    }
                })
            })
            .collect();
        for handle in threads {
            handle.join().unwrap();
        }
        assert_eq!(*count.lock().unwrap(), 1);
        assert_eq!(
            *Arc::get_mut(&mut init_once).unwrap().state.get_mut(),
            init_once_state::INITIALIZING
        );
        for _ in 0..50 {
            assert!(matches!(init_once.state(), InitState::Initializing));
        }
        drop(init_once);
    }

    /// After `poll_init` completes, further polls never call the initializer again.
    #[test]
    fn poll_init_only_once() {
        let mut once = InitOnce::new();
        let count = Arc::new(Mutex::new(0));
        assert_eq!(*count.lock().unwrap(), 0);
        if let InitState::Polling(mut poller) = once.state() {
            for i in 0..10 {
                _ = poller.poll_init(|| {
                    if i == 0 {
                        Poll::Ready((
                            420,
                            TrackDrop {
                                count: Arc::clone(&count),
                            },
                        ))
                    } else {
                        unreachable!()
                    }
                });
            }
        }
        let value = once.init(|| unreachable!());
        assert_eq!(value.0, 420);
        assert_eq!(*count.lock().unwrap(), 0);
        drop(once);
        assert_eq!(*count.lock().unwrap(), 1);
    }

    /// Dropping an `init_async` future mid-initialization stores nothing.
    #[tokio::test]
    async fn init_async_drop_future() {
        let mut once = InitOnce::new();
        let mut completed = false;
        {
            let future = once.init_async(async {
                tokio::task::yield_now().await;
                completed = true;
                420
            });
            let mut pinned_future = core::pin::pin!(future);
            // Poll exactly once. The initializer yields first, so this poll is Pending,
            // and the future is then dropped at the end of this scope.
            std::future::poll_fn(|cx| match pinned_future.as_mut().poll(cx) {
                Poll::Ready(_) => unreachable!(),
                Poll::Pending => Poll::Ready(()),
            })
            .await;
        }
        assert!(!completed, "future was dropped before completing");
        assert_ne!(
            *once.state.get_mut(),
            init_once_state::INITIALIZED,
            "cell should not have been initialized",
        );
    }
}