//! One-time initialization for `no_std` environments whose initializer can be
//! polled repeatedly until it is ready, e.g. from inside a `Future`.
//!
//! The first caller of `InitOnce::state` receives the only poller, as
//! [`InitState::Polling`]. Every other caller observes
//! [`InitState::Initializing`] until that poller publishes a value, and
//! [`InitState::Initialized`] afterwards.
#![cfg_attr(not(test), no_std)]
#![deny(missing_docs)]
use core::cell::UnsafeCell;
use core::mem::{needs_drop, MaybeUninit};
use core::task::Poll;
use portable_atomic::{self as atomic, AtomicUsize};
/// Raw values stored in the `state` atomic of an `InitOnce`.
///
/// The state only ever moves forward:
/// - `EMPTY -> INITIALIZING` when the poller is handed out.
/// - `INITIALIZING -> WRITING` when a `poll_init` call claims the cell to
///   store its value.
/// - `WRITING -> INITIALIZED` once that value is fully written.
mod init_once_state {
    /// No value is stored and no poller has been handed out yet.
    pub const EMPTY: usize = 0;
    /// A poller has been handed out but no value has been stored.
    pub const INITIALIZING: usize = 1;
    /// The value is fully written and may be read through shared references.
    pub const INITIALIZED: usize = 2;
    /// A `poll_init` call owns the cell and is writing the value into it.
    pub const WRITING: usize = 3;
}

/// Snapshot of an [`InitOnce`] as observed by `InitOnce::state`.
pub enum InitState<'a, T> {
    /// Another caller holds the [`PollInit`] and has not published a value yet.
    ///
    /// This state is permanent if that poller is dropped, or its initializer
    /// panics, before a value is stored.
    Initializing,
    /// The value has been published.
    Initialized(&'a T),
    /// This caller won the right to initialize the value. It must drive the
    /// [`PollInit`] until [`PollInit::poll_init`] returns `Poll::Ready`.
    Polling(PollInit<'a, T>),
}

/// A cell whose value is written at most once.
///
/// The initializer can be polled repeatedly (for example from inside a
/// `Future`) until it is ready. The first call to `InitOnce::state` hands
/// out the only [`PollInit`]. Every later call observes either
/// [`InitState::Initializing`] or [`InitState::Initialized`].
pub struct InitOnce<T> {
    // Holds a valid `T` only once `state` is `INITIALIZED`.
    cell: UnsafeCell<MaybeUninit<T>>,
    // One of the constants in `init_once_state`.
    state: AtomicUsize,
}

/// The right to initialize an [`InitOnce`], handed out exactly once through
/// [`InitState::Polling`].
///
/// Dropping it before [`PollInit::poll_init`] has returned `Poll::Ready`
/// leaves the `InitOnce` uninitialized forever.
pub struct PollInit<'a, T> {
    init_once: &'a InitOnce<T>,
}

// SAFETY: Through a shared `&InitOnce`, one thread stores a `T` that other
// threads then read via `&T`, which requires `T: Sync`. That `T` is later
// dropped by whichever thread ends up owning the `InitOnce`, which requires
// `T: Send`. These are the same bounds `std::sync::OnceLock` uses. All access
// to `cell` is ordered by the SeqCst operations on `state`.
unsafe impl<T: Sync + Send> Sync for InitOnce<T> {}

impl<T> Drop for InitOnce<T> {
    fn drop(&mut self) {
        // `&mut self` proves no poller or `&T` is alive, so no atomic access is needed.
        if needs_drop::<T>() && *self.state.get_mut() == init_once_state::INITIALIZED {
            // SAFETY: INITIALIZED is only stored after the value was fully
            // written, and `drop` runs once, so the value is dropped exactly once.
            unsafe { self.cell.get_mut().assume_init_drop() };
        }
    }
}

impl<T> Default for InitOnce<T> {
    #[inline]
    fn default() -> Self {
        Self::new()
    }
}

impl<T> InitOnce<T> {
    /// Creates an empty `InitOnce` that holds no value yet.
    pub const fn new() -> Self {
        Self {
            cell: UnsafeCell::new(MaybeUninit::uninit()),
            state: AtomicUsize::new(init_once_state::EMPTY),
        }
    }

    /// Returns the poller for `self`; only called after winning the
    /// `EMPTY -> INITIALIZING` transition.
    #[must_use]
    fn poll_init_begin(&self) -> PollInit<'_, T> {
        PollInit { init_once: self }
    }

    /// Returns a reference to the stored value.
    ///
    /// # Safety
    ///
    /// The caller must have observed `state == INITIALIZED`.
    unsafe fn value_unchecked(&self) -> &T {
        // SAFETY: guaranteed by the caller. INITIALIZED is stored only after
        // the value is fully written and the cell is never written again.
        unsafe { (*self.cell.get()).assume_init_ref() }
    }

    /// Returns the current state, claiming the right to initialize if nobody has.
    ///
    /// The first call on an empty `InitOnce` returns [`InitState::Polling`].
    /// Later calls return [`InitState::Initializing`] until the poller
    /// publishes a value, and [`InitState::Initialized`] afterwards.
    #[must_use = "The state of an InitOnce (i.e. InitState) must always be consumed. If you do \
                  not poll the value initializer to completion, the value will never be initialized."]
    pub fn state(&self) -> InitState<'_, T> {
        let claim = self.state.compare_exchange(
            init_once_state::EMPTY,
            init_once_state::INITIALIZING,
            atomic::Ordering::SeqCst,
            atomic::Ordering::SeqCst,
        );
        match claim {
            Ok(_) => InitState::Polling(self.poll_init_begin()),
            // A write in progress is still "initializing" from the outside.
            Err(init_once_state::INITIALIZING | init_once_state::WRITING) => {
                InitState::Initializing
            }
            // SAFETY: the failed CAS just observed INITIALIZED.
            Err(init_once_state::INITIALIZED) => {
                InitState::Initialized(unsafe { self.value_unchecked() })
            }
            Err(other) => unreachable!("invalid InitOnce state: {}", other),
        }
    }

    /// Initializes the value with `init` if nobody has started yet, and
    /// returns a reference to the stored value.
    ///
    /// Returns `None` without calling `init` if another caller holds the
    /// poller and has not published a value. This includes a poller that was
    /// dropped, or whose initializer panicked, which leaves the `InitOnce`
    /// in that state permanently.
    ///
    /// # Panics
    ///
    /// Propagates panics from `init`. No value is stored in that case.
    pub fn try_init<F>(&self, init: F) -> Option<&T>
    where
        F: FnOnce() -> T,
    {
        match self.state() {
            InitState::Initialized(value) => Some(value),
            InitState::Initializing => None,
            InitState::Polling(poller) => match poller.poll_init(|| Poll::Ready(init())) {
                Poll::Ready(value) => Some(value),
                Poll::Pending => unreachable!("an always-ready initializer cannot be pending"),
            },
        }
    }
}

impl<'init_once, T> PollInit<'init_once, T> {
    /// Polls `init` and, once it is ready, stores and publishes its value.
    ///
    /// Returns `Poll::Pending` and stores nothing when `init` is pending; call
    /// again (typically from the next poll of the surrounding future) to
    /// retry. Once a value has been published, further calls return it
    /// without calling `init`.
    ///
    /// If several threads call this concurrently through a shared
    /// `&PollInit`, exactly one ready value is stored. The other values are
    /// dropped, and every caller receives a reference to the stored value. A
    /// losing caller may spin briefly while the winner finishes its write.
    ///
    /// # Panics
    ///
    /// Propagates panics from `init`. No value is stored in that case.
    pub fn poll_init<F>(&self, init: F) -> Poll<&'init_once T>
    where
        F: FnOnce() -> Poll<T>,
    {
        let init_once = self.init_once;

        // Already published (e.g. polled again after `Ready`): never overwrite
        // a value that outstanding `&T`s may reference.
        if init_once.state.load(atomic::Ordering::SeqCst) == init_once_state::INITIALIZED {
            // SAFETY: INITIALIZED was just observed.
            return Poll::Ready(unsafe { init_once.value_unchecked() });
        }

        let value = core::task::ready!(init());

        let claimed = init_once
            .state
            .compare_exchange(
                init_once_state::INITIALIZING,
                init_once_state::WRITING,
                atomic::Ordering::SeqCst,
                atomic::Ordering::SeqCst,
            )
            .is_ok();

        if claimed {
            // SAFETY: winning INITIALIZING -> WRITING gives this call exclusive
            // access to the cell. No `&T` can exist before INITIALIZED is stored.
            unsafe { init_once.cell.get().cast::<T>().write(value) };
            init_once
                .state
                .store(init_once_state::INITIALIZED, atomic::Ordering::SeqCst);
        } else {
            // Another call through this poller stored its value first; keep
            // theirs. Their remaining work is a single write that cannot
            // panic, so this wait is short.
            drop(value);
            while init_once.state.load(atomic::Ordering::SeqCst) != init_once_state::INITIALIZED {
                core::hint::spin_loop();
            }
        }

        // SAFETY: INITIALIZED has been stored, either by this call or by the
        // winner the loop above waited for.
        Poll::Ready(unsafe { init_once.value_unchecked() })
    }
}
#[cfg(test)]
mod tests {
    use std::future::{self, Future};
    use std::sync::{Arc, Mutex};
    use std::thread;

    use super::*;

    /// Bumps a shared counter when dropped. Used to count destroyed values
    /// (and, in `never_poll_init`, how many callers were handed the poller).
    struct TrackDrop {
        count: Arc<Mutex<usize>>,
    }

    impl Drop for TrackDrop {
        fn drop(&mut self) {
            let mut total = self.count.lock().unwrap();
            *total += 1;
        }
    }

    /// Reads the raw state word; panics if another `Arc` clone is still alive.
    fn raw_state<T>(shared: &mut Arc<InitOnce<T>>) -> usize {
        *Arc::get_mut(shared).unwrap().state.get_mut()
    }

    #[tokio::test]
    async fn dropped_once_if_init() {
        let mut shared = Arc::new(InitOnce::<TrackDrop>::new());
        let drops = Arc::new(Mutex::new(0));
        assert_eq!(raw_state(&mut shared), init_once_state::EMPTY);

        let mut handles = Vec::new();
        for _ in 0..10 {
            let shared = Arc::clone(&shared);
            let drops = Arc::clone(&drops);
            handles.push(tokio::spawn(async move {
                // Only the single winner of the poller initializes the value.
                let InitState::Polling(poller) = shared.state() else {
                    return;
                };
                let ready_value = future::ready(TrackDrop { count: drops });
                let mut ready_value = std::pin::pin!(ready_value);
                let stored =
                    future::poll_fn(|cx| poller.poll_init(|| ready_value.as_mut().poll(cx)))
                        .await;
                // The stored value must not have been dropped while publishing it.
                assert_eq!(*stored.count.lock().unwrap(), 0);
            }));
        }
        for handle in handles {
            handle.await.unwrap();
        }

        assert_eq!(*drops.lock().unwrap(), 0);
        assert_eq!(raw_state(&mut shared), init_once_state::INITIALIZED);
        drop(shared);
        assert_eq!(*drops.lock().unwrap(), 1);
    }

    #[test]
    fn never_poll_init() {
        let mut shared = Arc::new(InitOnce::<()>::new());
        let winners = Arc::new(Mutex::new(0));
        assert_eq!(raw_state(&mut shared), init_once_state::EMPTY);
        assert_eq!(*winners.lock().unwrap(), 0);

        let mut workers = Vec::new();
        for _ in 0..10 {
            let shared = Arc::clone(&shared);
            let winners = Arc::clone(&winners);
            workers.push(thread::spawn(move || {
                // The poller is discarded immediately, so nothing is ever stored.
                if let InitState::Polling(_) = shared.state() {
                    drop(TrackDrop { count: winners });
                }
            }));
        }
        for worker in workers {
            worker.join().unwrap();
        }

        assert_eq!(*winners.lock().unwrap(), 1);
        assert_eq!(raw_state(&mut shared), init_once_state::INITIALIZING);
        for _ in 0..50 {
            assert!(matches!(shared.state(), InitState::Initializing));
        }
        drop(shared);
    }
}