init_once/
lib.rs

//! The `init_once` crate provides a mechanism to attempt to read a value without
//! blocking the caller, in case it is being initialized concurrently. Such an
//! abstraction might be useful in cache implementations whose consumers might
//! not want to block while the cache fills up with data.
5
6#![cfg_attr(not(test), no_std)]
7#![deny(missing_docs)]
8
9use core::cell::UnsafeCell;
10use core::future::{self, Future};
11use core::hint::unreachable_unchecked;
12use core::mem::{needs_drop, MaybeUninit};
13use core::task::Poll;
14
15use portable_atomic::{self as atomic, AtomicUsize};
16
/// Raw values held by the `state` atomic of an [`InitOnce`].
///
/// A cell starts out as [`EMPTY`](self::EMPTY), is claimed by exactly one
/// caller of [`InitOnce::state`] (moving it to
/// [`INITIALIZING`](self::INITIALIZING)), and ends up as
/// [`INITIALIZED`](self::INITIALIZED) once a value has been written.
mod init_once_state {
    /// The cell is currently empty.
    ///
    /// No value has been written, and no caller has claimed the right to
    /// initialize the cell yet.
    pub const EMPTY: usize = 0;
    /// The cell is being initialized.
    ///
    /// Some caller has claimed the cell, but no value is available yet.
    pub const INITIALIZING: usize = 1;
    /// The cell is fully initialized.
    ///
    /// The inner value has been written, and may be read.
    pub const INITIALIZED: usize = 2;
}
25
/// Initialization state of an [`InitOnce`] instance, as reported by
/// [`InitOnce::state`].
#[derive(Debug)]
pub enum InitState<'a, T> {
    /// The inner value is currently being initialized by another caller.
    ///
    /// No value can be read yet; the state may be queried again later.
    Initializing,
    /// The inner value is initialized, and can be read through the
    /// contained reference.
    Initialized(&'a T),
    /// The inner value requires initialization via [`PollInit`].
    ///
    /// Whoever receives this variant is responsible for driving the
    /// contained [`PollInit`] to completion.
    Polling(PollInit<'a, T>),
}
36
/// Lazily initialize a value of some arbitrary type `T`.
///
/// Reading the value doesn't block the caller if it is being initialized
/// concurrently: callers are told that the initialization is in progress
/// instead (see [`InitState::Initializing`]).
#[derive(Debug)]
pub struct InitOnce<T> {
    // storage of the lazily initialized value. it only holds a valid `T`
    // once `state` reads `init_once_state::INITIALIZED`.
    cell: UnsafeCell<MaybeUninit<T>>,
    // one of the `init_once_state` constants, tracking whether `cell` is
    // empty, being initialized, or initialized.
    state: AtomicUsize,
}
45
/// Polling mechanism to initialize a value contained in some
/// [`InitOnce`] instance.
///
/// Instances are handed out by [`InitOnce::state`], wrapped in
/// [`InitState::Polling`].
#[derive(Debug)]
pub struct PollInit<'a, T> {
    // set once `poll_init` has written the value, so that later polls
    // return the stored value instead of running the initializer again.
    polled_to_completion: bool,
    // the cell this poller is responsible for initializing.
    init_once: &'a InitOnce<T>,
}
53
54// SAFETY: should be safe to share between threads
55// if `T` is also `Sync`. the atomic operations
56// guarantee that every thread will see the same
57// data.
58unsafe impl<T: Sync> Sync for InitOnce<T> {}
59
60impl<T> Drop for InitOnce<T> {
61    // NB: it is guaranteed that at least one thread calls `Drop`, since
62    // we must block to initialize from at least one thread
63    fn drop(&mut self) {
64        // NB: `InitOnce` doesn't implement clone, so by the time we get dropped,
65        // we are the only live instance (unless we are inside of an `Arc`). we
66        // do not need to go through atomic memory loads to check if the inner
67        // value is initialized.
68        if needs_drop::<T>() && *self.state.get_mut() == init_once_state::INITIALIZED {
69            // SAFETY: the value is initialized, so we can drop it
70            unsafe {
71                self.cell.get_mut().assume_init_drop();
72            }
73        }
74    }
75}
76
77impl<T> Default for InitOnce<T> {
78    #[inline]
79    fn default() -> Self {
80        Self::new()
81    }
82}
83
84impl<T> InitOnce<T> {
85    /// Create an uninitialized [`InitOnce`].
86    pub const fn new() -> Self {
87        Self {
88            cell: UnsafeCell::new(MaybeUninit::uninit()),
89            state: AtomicUsize::new(init_once_state::EMPTY),
90        }
91    }
92
93    #[must_use]
94    fn poll_init_begin(&self) -> PollInit<'_, T> {
95        PollInit {
96            init_once: self,
97            polled_to_completion: false,
98        }
99    }
100
101    /// Query the state of an [`InitOnce`] instance.
102    ///
103    /// If the current state is [`InitState::Polling`], the caller is
104    /// responsible for polling the init function to completion.
105    #[must_use = "The state of an InitOnce (i.e. InitState) must always be consumed. If you do \
106             not poll the value initializer to completion, the value will never be initialized."]
107    #[inline]
108    pub fn state(&self) -> InitState<'_, T> {
109        self.state
110            .compare_exchange(
111                init_once_state::EMPTY,
112                init_once_state::INITIALIZING,
113                atomic::Ordering::SeqCst,
114                atomic::Ordering::SeqCst,
115            )
116            .map_or_else(
117                |current_value| match current_value {
118                    init_once_state::INITIALIZING => InitState::Initializing,
119                    init_once_state::INITIALIZED => {
120                        InitState::Initialized({
121                            // SAFETY: the data returned by the atomic load guarantees
122                            // that the value has been initialized.
123                            unsafe { (*self.cell.get()).assume_init_ref() }
124                        })
125                    }
126                    _ => {
127                        // SAFETY: we attempted to atomically swap `init_once_state::EMPTY`
128                        // with `init_once_state::INITIALIZING` and failed. the safety of
129                        // this `unreachable_unchecked` is guaranteed by the CAS, whose
130                        // previous value couldn't have been `init_once_state::EMPTY`.
131                        unsafe { unreachable_unchecked() }
132                    }
133                },
134                |_| unlikely_call(|| InitState::Polling(self.poll_init_begin())),
135            )
136    }
137
138    /// Attempt to initialize this [`InitOnce`] with the value returned by the closure `init`.
139    #[inline]
140    pub fn try_init<F>(&self, mut init: F) -> Option<&T>
141    where
142        F: FnMut() -> T,
143    {
144        match self.state() {
145            InitState::Initialized(value) => Some(value),
146            InitState::Initializing => None,
147            InitState::Polling(mut poller) => match poller.poll_init(|| Poll::Ready(init())) {
148                Poll::Ready(value) => Some(value),
149                Poll::Pending => {
150                    // SAFETY: we pass `Poll::Ready` to `poll_init` above, therefore
151                    // it is impossible to reach this `Poll::Pending` branch
152                    unsafe { unreachable_unchecked() }
153                }
154            },
155        }
156    }
157
158    /// Attempt to initialize this [`InitOnce`] with the value returned by the future `init`.
159    pub async fn try_init_async<F>(&self, init: F) -> Option<&T>
160    where
161        F: Future<Output = T>,
162    {
163        match self.state() {
164            InitState::Initialized(value) => Some(value),
165            InitState::Initializing => None,
166            InitState::Polling(mut poller) => Some(poller.init_async(init).await),
167        }
168    }
169
170    /// Initialize this [`InitOnce`] with the value returned by the closure `init`.
171    pub fn init<F>(&mut self, mut init: F) -> &mut T
172    where
173        F: FnMut() -> T,
174    {
175        let maybe_uninit = self.cell.get_mut();
176
177        if *self.state.get_mut() != init_once_state::INITIALIZED {
178            unlikely_call(|| {
179                maybe_uninit.write(init());
180                *self.state.get_mut() = init_once_state::INITIALIZED;
181            });
182        }
183
184        // SAFETY: we hold the only reference to the `InitOnce` cell,
185        // and we guarantee that it is always initialized with the
186        // call above
187        unsafe { maybe_uninit.assume_init_mut() }
188    }
189
190    /// Initialize this [`InitOnce`] with the value returned by the future `init`.
191    pub async fn init_async<F>(&mut self, init: F) -> &mut T
192    where
193        F: Future<Output = T>,
194    {
195        let maybe_uninit = self.cell.get_mut();
196
197        if *self.state.get_mut() != init_once_state::INITIALIZED {
198            unlikely_call(|| async {
199                maybe_uninit.write(init.await);
200                *self.state.get_mut() = init_once_state::INITIALIZED;
201            })
202            .await;
203        }
204
205        // SAFETY: we hold the only reference to the `InitOnce` cell,
206        // and we guarantee that it is always initialized with the
207        // call above
208        unsafe { maybe_uninit.assume_init_mut() }
209    }
210}
211
212impl<'init_once, T> PollInit<'init_once, T> {
213    /// Initialize the associated [`InitOnce`] with the given future `init`.
214    pub async fn init_async<F>(&mut self, mut init: F) -> &'init_once T
215    where
216        F: Future<Output = T>,
217    {
218        let mut pinned_init = core::pin::pin!(init);
219        future::poll_fn(|cx| self.poll_init(|| pinned_init.as_mut().poll(cx))).await
220    }
221
222    /// Check if the value returned by `init` is ready.
223    pub fn poll_init<F>(&mut self, mut init: F) -> Poll<&'init_once T>
224    where
225        F: FnMut() -> Poll<T>,
226    {
227        if self.polled_to_completion {
228            return unlikely_call(|| {
229                Poll::Ready({
230                    // SAFETY: the poll method has been polled to completion,
231                    // and we hold an exclusive reference to the poller,
232                    // therefore the state has been initialized and isn't
233                    // being overwritten concurrently
234                    unsafe { (*self.init_once.cell.get()).assume_init_ref() }
235                })
236            });
237        }
238
239        let value = core::task::ready!(init());
240
241        // SAFETY: we CAS'd `init_once_state::EMPTY` with `init_once_state::INITIALIZING`.
242        // the state of the cell cannot be set to `init_once_state::EMPTY` ever again, therefore
243        // only one `PollInit` instance can ever be created. in this method, we hold an exclusive
244        // reference to that sole instance. therefore, we are the only caller able to write a
245        // value to the inner `UnsafeCell`.
246        let slot = unsafe { (*self.init_once.cell.get()).as_mut_ptr() };
247
248        // SAFETY: same as above.
249        unsafe {
250            core::ptr::write(slot, value);
251        }
252
253        self.init_once
254            .state
255            .store(init_once_state::INITIALIZED, atomic::Ordering::SeqCst);
256
257        self.polled_to_completion = true;
258
259        Poll::Ready({
260            // SAFETY: we atomically stored `init_once_state::INITIALIZED`
261            // onto `state`, and initialized the value returned from `init`.
262            unsafe { (*self.init_once.cell.get()).assume_init_ref() }
263        })
264    }
265}
266
/// Run `f`, hinting to the compiler that this call happens rarely.
///
/// The function is never inlined and is marked as cold, which keeps the
/// wrapped code out of the hot path of its callers.
#[cold]
#[inline(never)]
fn unlikely_call<T, F>(f: F) -> T
where
    F: FnOnce() -> T,
{
    let output = f();
    output
}
272
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex};
    use std::thread;

    use super::*;

    /// Increments the shared counter every time an instance is dropped.
    struct TrackDrop {
        count: Arc<Mutex<usize>>,
    }

    impl Drop for TrackDrop {
        fn drop(&mut self) {
            *self.count.lock().unwrap() += 1;
        }
    }

    /// While some thread is busy running its initializer, `try_init` called
    /// from another thread returns `None` instead of blocking.
    #[test]
    fn try_init_wont_block() {
        struct Shared {
            init_once: InitOnce<()>,
            // released once the worker is inside of its initializer
            init_started: std::sync::Barrier,
            // released once the worker is allowed to finish initializing
            init_release: std::sync::Barrier,
        }

        let shared = Arc::new(Shared {
            init_once: InitOnce::new(),
            init_started: std::sync::Barrier::new(2),
            init_release: std::sync::Barrier::new(2),
        });

        let shared2 = Arc::clone(&shared);

        let handle = thread::spawn(move || {
            assert!(shared2
                .init_once
                .try_init(|| {
                    shared2.init_started.wait();
                    shared2.init_release.wait();
                })
                .is_some());
        });

        // NB: past this barrier, the worker is guaranteed to be inside of
        // its initializer, i.e. the cell is being initialized
        shared.init_started.wait();
        assert!(shared.init_once.try_init(|| panic!()).is_none());

        shared.init_release.wait();
        handle.join().unwrap();
        assert!(shared.init_once.try_init(|| panic!()).is_some());
    }

    /// While some task is busy awaiting its initializer, `try_init` and
    /// `try_init_async` called from another task return `None` instead of
    /// waiting.
    #[tokio::test]
    async fn try_init_async_wont_block() {
        struct Shared {
            init_once: InitOnce<()>,
            // released once the worker is inside of its initializer
            init_started: tokio::sync::Barrier,
            // released once the worker is allowed to finish initializing
            init_release: tokio::sync::Barrier,
        }

        let shared = Arc::new(Shared {
            init_once: InitOnce::new(),
            init_started: tokio::sync::Barrier::new(2),
            init_release: tokio::sync::Barrier::new(2),
        });

        let shared2 = Arc::clone(&shared);

        let handle = tokio::spawn(async move {
            assert!(shared2
                .init_once
                .try_init_async(async {
                    shared2.init_started.wait().await;
                    shared2.init_release.wait().await;
                })
                .await
                .is_some());
        });

        // NB: past this barrier, the worker is guaranteed to be inside of
        // its initializer, i.e. the cell is being initialized
        shared.init_started.wait().await;
        assert!(shared.init_once.try_init(|| panic!()).is_none());
        assert!(shared
            .init_once
            .try_init_async(async { panic!() })
            .await
            .is_none());

        shared.init_release.wait().await;
        handle.await.unwrap();
        assert!(shared.init_once.try_init(|| panic!()).is_some());
        assert!(shared
            .init_once
            .try_init_async(async { panic!() })
            .await
            .is_some());
    }

    /// `init` runs its closure only on the first call.
    #[test]
    fn init_mut_only_once() {
        let mut initialized = 0;
        let mut init_once = InitOnce::new();

        for _ in 0..10 {
            init_once.init(|| {
                initialized += 1;
            });
        }

        assert_eq!(initialized, 1);
    }

    /// `init_async` awaits its future only on the first call.
    #[tokio::test]
    async fn init_mut_async_only_once() {
        let mut initialized = 0;
        let mut init_once = InitOnce::new();

        for _ in 0..10 {
            init_once
                .init_async(async {
                    initialized += 1;
                })
                .await;
        }

        assert_eq!(initialized, 1);
    }

    /// A value stored through a poller is dropped exactly once, when the
    /// `InitOnce` itself is dropped.
    #[tokio::test]
    async fn dropped_once_if_init() {
        let mut init_once = Arc::new(InitOnce::new());
        let count = Arc::new(Mutex::new(0));

        assert_eq!(
            *Arc::get_mut(&mut init_once).unwrap().state.get_mut(),
            init_once_state::EMPTY
        );

        let tasks: Vec<_> = (0..10)
            .map(|_| {
                let init_once = Arc::clone(&init_once);
                let count = Arc::clone(&count);

                tokio::spawn(async move {
                    // only the task that claimed the cell stores a value
                    if let InitState::Polling(mut poller) = init_once.state() {
                        let TrackDrop {
                            count: current_count,
                        } = poller.init_async(future::ready(TrackDrop { count })).await;

                        assert_eq!(*current_count.lock().unwrap(), 0);
                    }
                })
            })
            .collect();

        for handle in tasks {
            handle.await.unwrap();
        }

        assert_eq!(*count.lock().unwrap(), 0);
        assert_eq!(
            *Arc::get_mut(&mut init_once).unwrap().state.get_mut(),
            init_once_state::INITIALIZED
        );

        drop(init_once);
        assert_eq!(*count.lock().unwrap(), 1);
    }

    /// Only one caller is handed a poller. If that poller is never polled,
    /// the cell remains in the initializing state.
    #[test]
    fn never_poll_init() {
        let mut init_once = Arc::new(InitOnce::<()>::new());
        let count = Arc::new(Mutex::new(0));

        assert_eq!(
            *Arc::get_mut(&mut init_once).unwrap().state.get_mut(),
            init_once_state::EMPTY
        );

        assert_eq!(*count.lock().unwrap(), 0);

        let threads: Vec<_> = (0..10)
            .map(|_| {
                let init_once = Arc::clone(&init_once);
                let count = Arc::clone(&count);

                thread::spawn(move || {
                    // count how many threads got to claim the cell
                    if matches!(init_once.state(), InitState::Polling(_)) {
                        drop(TrackDrop { count });
                    }
                })
            })
            .collect();

        for handle in threads {
            handle.join().unwrap();
        }

        assert_eq!(*count.lock().unwrap(), 1);

        assert_eq!(
            *Arc::get_mut(&mut init_once).unwrap().state.get_mut(),
            init_once_state::INITIALIZING
        );

        for _ in 0..50 {
            assert!(matches!(init_once.state(), InitState::Initializing));
        }

        drop(init_once);
    }

    /// Once a poller has been polled to completion, further polls don't run
    /// the initializer again.
    #[test]
    fn poll_init_only_once() {
        let mut once = InitOnce::new();
        let count = Arc::new(Mutex::new(0));

        assert_eq!(*count.lock().unwrap(), 0);

        if let InitState::Polling(mut poller) = once.state() {
            for i in 0..10 {
                _ = poller.poll_init(|| {
                    if i == 0 {
                        Poll::Ready((
                            420,
                            TrackDrop {
                                count: Arc::clone(&count),
                            },
                        ))
                    } else {
                        unreachable!()
                    }
                });
            }
        }

        let value = once.init(|| unreachable!());
        assert_eq!(value.0, 420);

        assert_eq!(*count.lock().unwrap(), 0);
        drop(once);
        assert_eq!(*count.lock().unwrap(), 1);
    }

    /// Dropping the future returned by `init_async` before it completes
    /// leaves the cell uninitialized.
    #[tokio::test]
    async fn init_async_drop_future() {
        let mut once = InitOnce::new();
        let mut completed = false;

        {
            let future = once.init_async(async {
                tokio::task::yield_now().await;
                completed = true;
                420
            });
            let mut pinned_future = core::pin::pin!(future);

            // poll the future a single time, then drop it at the end of
            // this scope, while it is still pending
            std::future::poll_fn(|cx| match pinned_future.as_mut().poll(cx) {
                Poll::Ready(_) => unreachable!(),
                Poll::Pending => Poll::Ready(()),
            })
            .await;
        }

        assert!(!completed, "future was dropped before completing");
        assert_ne!(
            *once.state.get_mut(),
            init_once_state::INITIALIZED,
            "cell should not have been initialized",
        );
    }
}