once_cell 1.0.0

Single assignment cells and lazy values.
Documentation
// There's a lot of scary concurrent code in this module, but it is copied from
// `std::sync::Once` with two changes:
//   * no poisoning
//   * init function can fail

use std::{
    cell::UnsafeCell,
    marker::PhantomData,
    panic::{RefUnwindSafe, UnwindSafe},
    ptr,
    sync::atomic::{AtomicBool, AtomicUsize, Ordering},
    thread::{self, Thread},
};

#[derive(Debug)]
pub(crate) struct OnceCell<T> {
    // This `state` word is actually an encoded version of just a pointer to a
    // `Waiter`, so we add the `PhantomData` appropriately.
    state: AtomicUsize,
    _marker: PhantomData<*mut Waiter>,
    // FIXME: switch to `std::mem::MaybeUninit` once we are ready to bump MSRV
    // that far. It was stabilized in 1.36.0, so, if you are reading this and
    // it's higher than 1.46.0 outside, please send a PR! ;) (and to the same
    // for `Lazy`, while we are at it).
    value: UnsafeCell<Option<T>>,
}

// Why do we need `T: Send`?
// Thread A creates a `OnceCell` and shares it with
// scoped thread B, which fills the cell, which is
// then destroyed by A. That is, destructor observes
// a sent value.
unsafe impl<T: Sync + Send> Sync for OnceCell<T> {}
unsafe impl<T: Send> Send for OnceCell<T> {}

impl<T: RefUnwindSafe + UnwindSafe> RefUnwindSafe for OnceCell<T> {}
impl<T: UnwindSafe> UnwindSafe for OnceCell<T> {}

// Three states that a OnceCell can be in, encoded into the lower bits of `state` in
// the OnceCell structure.
const INCOMPLETE: usize = 0x0;
const RUNNING: usize = 0x1;
const COMPLETE: usize = 0x2;

// Mask to learn about the state. All other bits are the queue of waiters if
// this is in the RUNNING state.
const STATE_MASK: usize = 0x3;

// Representation of a node in the linked list of waiters in the RUNNING state.
struct Waiter {
    thread: Option<Thread>,
    signaled: AtomicBool,
    next: *mut Waiter,
}

// Helper struct used to clean up after a closure call with a `Drop`
// implementation to also run on panic.
struct Finish<'a> {
    failed: bool,
    my_state: &'a AtomicUsize,
}

impl<T> OnceCell<T> {
    pub(crate) const fn new() -> OnceCell<T> {
        OnceCell {
            state: AtomicUsize::new(INCOMPLETE),
            _marker: PhantomData,
            value: UnsafeCell::new(None),
        }
    }

    pub(crate) fn into_inner(self) -> Option<T> {
        // Because `into_inner` takes `self` by value, the compiler statically verifies
        // that it is not currently borrowed. So it is safe to move out `Option<T>`.
        self.value.into_inner()
    }

    pub(crate) fn get(&self) -> Option<&T> {
        if self.is_completed() {
            let slot: &Option<T> = unsafe { &*self.value.get() };
            match slot {
                Some(value) => Some(value),
                // This unsafe does improve performance, see `examples/bench`.
                None => unsafe { std::hint::unreachable_unchecked() },
            }
        } else {
            None
        }
    }

    pub fn get_or_try_init<F, E>(&self, f: F) -> Result<&T, E>
    where
        F: FnOnce() -> Result<T, E>,
    {
        // Fast path check
        if let Some(value) = self.get() {
            return Ok(value);
        }

        let mut f = Some(f);
        let mut err: Option<E> = None;
        let slot = &self.value;
        get_or_try_init_inner(&self.state, &mut || {
            let f = f.take().unwrap();
            match f() {
                Ok(value) => {
                    unsafe { *slot.get() = Some(value) };
                    true
                }
                Err(e) => {
                    err = Some(e);
                    false
                }
            }
        });
        match err {
            Some(err) => Err(err),
            None => {
                let value: &T = unsafe { &*slot.get() }.as_ref().unwrap();
                Ok(value)
            }
        }
    }

    #[inline]
    fn is_completed(&self) -> bool {
        // An `Acquire` load is enough because that makes all the initialization
        // operations visible to us, and, this being a fast path, weaker
        // ordering helps with performance. This `Acquire` synchronizes with
        // `SeqCst` operations on the slow path.
        self.state.load(Ordering::Acquire) == COMPLETE
    }
}

// Note: this is intentionally monomorphic
#[cold]
fn get_or_try_init_inner(my_state: &AtomicUsize, init: &mut dyn FnMut() -> bool) -> bool {
    // This cold path uses SeqCst consistently because the
    // performance difference really does not matter there, and
    // SeqCst minimizes the chances of something going wrong.
    let mut state = my_state.load(Ordering::SeqCst);

    'outer: loop {
        match state {
            // If we're complete, then there's nothing to do, we just
            // jettison out as we shouldn't run the closure.
            COMPLETE => return true,

            // Otherwise if we see an incomplete state we will attempt to
            // move ourselves into the RUNNING state. If we succeed, then
            // the queue of waiters starts at null (all 0 bits).
            INCOMPLETE => {
                let old = my_state.compare_and_swap(state, RUNNING, Ordering::SeqCst);
                if old != state {
                    state = old;
                    continue;
                }

                // Run the initialization routine, letting it know if we're
                // poisoned or not. The `Finish` struct is then dropped, and
                // the `Drop` implementation here is responsible for waking
                // up other waiters both in the normal return and panicking
                // case.
                let mut complete = Finish { failed: true, my_state };
                let success = init();
                complete.failed = !success;
                return success;
            }

            // All other values we find should correspond to the RUNNING
            // state with an encoded waiter list in the more significant
            // bits. We attempt to enqueue ourselves by moving us to the
            // head of the list and bail out if we ever see a state that's
            // not RUNNING.
            _ => {
                assert!(state & STATE_MASK == RUNNING);
                let mut node = Waiter {
                    thread: Some(thread::current()),
                    signaled: AtomicBool::new(false),
                    next: ptr::null_mut(),
                };
                let me = &mut node as *mut Waiter as usize;
                assert!(me & STATE_MASK == 0);

                while state & STATE_MASK == RUNNING {
                    node.next = (state & !STATE_MASK) as *mut Waiter;
                    let old = my_state.compare_and_swap(state, me | RUNNING, Ordering::SeqCst);
                    if old != state {
                        state = old;
                        continue;
                    }

                    // Once we've enqueued ourselves, wait in a loop.
                    // Afterwards reload the state and continue with what we
                    // were doing from before.
                    while !node.signaled.load(Ordering::SeqCst) {
                        thread::park();
                    }
                    state = my_state.load(Ordering::SeqCst);
                    continue 'outer;
                }
            }
        }
    }
}

impl Drop for Finish<'_> {
    fn drop(&mut self) {
        // Swap out our state with however we finished. We should only ever see
        // an old state which was RUNNING.
        let queue = if self.failed {
            self.my_state.swap(INCOMPLETE, Ordering::SeqCst)
        } else {
            self.my_state.swap(COMPLETE, Ordering::SeqCst)
        };
        assert_eq!(queue & STATE_MASK, RUNNING);

        // Decode the RUNNING to a list of waiters, then walk that entire list
        // and wake them up. Note that it is crucial that after we store `true`
        // in the node it can be free'd! As a result we load the `thread` to
        // signal ahead of time and then unpark it after the store.
        unsafe {
            let mut queue = (queue & !STATE_MASK) as *mut Waiter;
            while !queue.is_null() {
                let next = (*queue).next;
                let thread = (*queue).thread.take().unwrap();
                (*queue).signaled.store(true, Ordering::SeqCst);
                thread.unpark();
                queue = next;
            }
        }
    }
}

// These test are snatched from std as well.
#[cfg(test)]
mod tests {
    use std::panic;
    #[cfg(not(miri))] // miri doesn't support threads
    use std::{sync::mpsc::channel, thread};

    use super::OnceCell;

    impl<T> OnceCell<T> {
        fn init(&self, f: impl FnOnce() -> T) {
            enum Void {}
            let _ = self.get_or_try_init(|| Ok::<T, Void>(f()));
        }
    }

    #[test]
    fn smoke_once() {
        static O: OnceCell<()> = OnceCell::new();
        let mut a = 0;
        O.init(|| a += 1);
        assert_eq!(a, 1);
        O.init(|| a += 1);
        assert_eq!(a, 1);
    }

    #[test]
    #[cfg(not(miri))] // miri doesn't support threads
    fn stampede_once() {
        static O: OnceCell<()> = OnceCell::new();
        static mut RUN: bool = false;

        let (tx, rx) = channel();
        for _ in 0..10 {
            let tx = tx.clone();
            thread::spawn(move || {
                for _ in 0..4 {
                    thread::yield_now()
                }
                unsafe {
                    O.init(|| {
                        assert!(!RUN);
                        RUN = true;
                    });
                    assert!(RUN);
                }
                tx.send(()).unwrap();
            });
        }

        unsafe {
            O.init(|| {
                assert!(!RUN);
                RUN = true;
            });
            assert!(RUN);
        }

        for _ in 0..10 {
            rx.recv().unwrap();
        }
    }

    #[test]
    #[cfg(not(miri))] // miri doesn't support panics
    fn poison_bad() {
        static O: OnceCell<()> = OnceCell::new();

        // poison the once
        let t = panic::catch_unwind(|| {
            O.init(|| panic!());
        });
        assert!(t.is_err());

        // we can subvert poisoning, however
        let mut called = false;
        O.init(|| {
            called = true;
        });
        assert!(called);

        // once any success happens, we stop propagating the poison
        O.init(|| {});
    }

    #[test]
    #[cfg(not(miri))] // miri doesn't support panics
    fn wait_for_force_to_finish() {
        static O: OnceCell<()> = OnceCell::new();

        // poison the once
        let t = panic::catch_unwind(|| {
            O.init(|| panic!());
        });
        assert!(t.is_err());

        // make sure someone's waiting inside the once via a force
        let (tx1, rx1) = channel();
        let (tx2, rx2) = channel();
        let t1 = thread::spawn(move || {
            O.init(|| {
                tx1.send(()).unwrap();
                rx2.recv().unwrap();
            });
        });

        rx1.recv().unwrap();

        // put another waiter on the once
        let t2 = thread::spawn(|| {
            let mut called = false;
            O.init(|| {
                called = true;
            });
            assert!(!called);
        });

        tx2.send(()).unwrap();

        assert!(t1.join().is_ok());
        assert!(t2.join().is_ok());
    }
}