//! One-time initialization of a shared value, driven either by a synchronous
//! closure or by polling a future, without ever blocking other callers.
//!
//! The central type is [`InitOnce`]. [`InitOnce::state`] grants at most one
//! caller the right to initialize the value (as a [`PollInit`]); every other
//! caller is told that initialization is in progress instead of waiting for it.
//!
//! Outside of its own tests the crate is `no_std`.
#![cfg_attr(not(test), no_std)]
#![deny(missing_docs)]
8
9use core::cell::UnsafeCell;
10use core::future::{self, Future};
11use core::hint::unreachable_unchecked;
12use core::mem::{needs_drop, MaybeUninit};
13use core::task::Poll;
14
15use portable_atomic::{self as atomic, AtomicUsize};
16
mod init_once_state {
    //! Values held by `InitOnce::state`. They never move backwards: a cell goes
    //! from `EMPTY` to `INITIALIZED`, possibly passing through `INITIALIZING`.

    /// Nothing is stored and nobody has claimed the right to initialize.
    pub const EMPTY: usize = 0;
    /// A `PollInit` has been handed out; the value is not stored yet.
    pub const INITIALIZING: usize = EMPTY + 1;
    /// The value is stored and may be read.
    pub const INITIALIZED: usize = INITIALIZING + 1;
}
25
26#[derive(Debug)]
28pub enum InitState<'a, T> {
29 Initializing,
31 Initialized(&'a T),
33 Polling(PollInit<'a, T>),
35}
36
37#[derive(Debug)]
41pub struct InitOnce<T> {
42 cell: UnsafeCell<MaybeUninit<T>>,
43 state: AtomicUsize,
44}
45
46#[derive(Debug)]
49pub struct PollInit<'a, T> {
50 polled_to_completion: bool,
51 init_once: &'a InitOnce<T>,
52}
53
54unsafe impl<T: Sync> Sync for InitOnce<T> {}
59
60impl<T> Drop for InitOnce<T> {
61 fn drop(&mut self) {
64 if needs_drop::<T>() && *self.state.get_mut() == init_once_state::INITIALIZED {
69 unsafe {
71 self.cell.get_mut().assume_init_drop();
72 }
73 }
74 }
75}
76
77impl<T> Default for InitOnce<T> {
78 #[inline]
79 fn default() -> Self {
80 Self::new()
81 }
82}
83
84impl<T> InitOnce<T> {
85 pub const fn new() -> Self {
87 Self {
88 cell: UnsafeCell::new(MaybeUninit::uninit()),
89 state: AtomicUsize::new(init_once_state::EMPTY),
90 }
91 }
92
93 #[must_use]
94 fn poll_init_begin(&self) -> PollInit<'_, T> {
95 PollInit {
96 init_once: self,
97 polled_to_completion: false,
98 }
99 }
100
101 #[must_use = "The state of an InitOnce (i.e. InitState) must always be consumed. If you do \
106 not poll the value initializer to completion, the value will never be initialized."]
107 #[inline]
108 pub fn state(&self) -> InitState<'_, T> {
109 self.state
110 .compare_exchange(
111 init_once_state::EMPTY,
112 init_once_state::INITIALIZING,
113 atomic::Ordering::SeqCst,
114 atomic::Ordering::SeqCst,
115 )
116 .map_or_else(
117 |current_value| match current_value {
118 init_once_state::INITIALIZING => InitState::Initializing,
119 init_once_state::INITIALIZED => {
120 InitState::Initialized({
121 unsafe { (*self.cell.get()).assume_init_ref() }
124 })
125 }
126 _ => {
127 unsafe { unreachable_unchecked() }
132 }
133 },
134 |_| unlikely_call(|| InitState::Polling(self.poll_init_begin())),
135 )
136 }
137
138 #[inline]
140 pub fn try_init<F>(&self, mut init: F) -> Option<&T>
141 where
142 F: FnMut() -> T,
143 {
144 match self.state() {
145 InitState::Initialized(value) => Some(value),
146 InitState::Initializing => None,
147 InitState::Polling(mut poller) => match poller.poll_init(|| Poll::Ready(init())) {
148 Poll::Ready(value) => Some(value),
149 Poll::Pending => {
150 unsafe { unreachable_unchecked() }
153 }
154 },
155 }
156 }
157
158 pub async fn try_init_async<F>(&self, init: F) -> Option<&T>
160 where
161 F: Future<Output = T>,
162 {
163 match self.state() {
164 InitState::Initialized(value) => Some(value),
165 InitState::Initializing => None,
166 InitState::Polling(mut poller) => Some(poller.init_async(init).await),
167 }
168 }
169
170 pub fn init<F>(&mut self, mut init: F) -> &mut T
172 where
173 F: FnMut() -> T,
174 {
175 let maybe_uninit = self.cell.get_mut();
176
177 if *self.state.get_mut() != init_once_state::INITIALIZED {
178 unlikely_call(|| {
179 maybe_uninit.write(init());
180 *self.state.get_mut() = init_once_state::INITIALIZED;
181 });
182 }
183
184 unsafe { maybe_uninit.assume_init_mut() }
188 }
189
190 pub async fn init_async<F>(&mut self, init: F) -> &mut T
192 where
193 F: Future<Output = T>,
194 {
195 let maybe_uninit = self.cell.get_mut();
196
197 if *self.state.get_mut() != init_once_state::INITIALIZED {
198 unlikely_call(|| async {
199 maybe_uninit.write(init.await);
200 *self.state.get_mut() = init_once_state::INITIALIZED;
201 })
202 .await;
203 }
204
205 unsafe { maybe_uninit.assume_init_mut() }
209 }
210}
211
212impl<'init_once, T> PollInit<'init_once, T> {
213 pub async fn init_async<F>(&mut self, mut init: F) -> &'init_once T
215 where
216 F: Future<Output = T>,
217 {
218 let mut pinned_init = core::pin::pin!(init);
219 future::poll_fn(|cx| self.poll_init(|| pinned_init.as_mut().poll(cx))).await
220 }
221
222 pub fn poll_init<F>(&mut self, mut init: F) -> Poll<&'init_once T>
224 where
225 F: FnMut() -> Poll<T>,
226 {
227 if self.polled_to_completion {
228 return unlikely_call(|| {
229 Poll::Ready({
230 unsafe { (*self.init_once.cell.get()).assume_init_ref() }
235 })
236 });
237 }
238
239 let value = core::task::ready!(init());
240
241 let slot = unsafe { (*self.init_once.cell.get()).as_mut_ptr() };
247
248 unsafe {
250 core::ptr::write(slot, value);
251 }
252
253 self.init_once
254 .state
255 .store(init_once_state::INITIALIZED, atomic::Ordering::SeqCst);
256
257 self.polled_to_completion = true;
258
259 Poll::Ready({
260 unsafe { (*self.init_once.cell.get()).assume_init_ref() }
263 })
264 }
265}
266
/// Invokes `func` and returns its result unchanged.
///
/// `#[cold]` and `#[inline(never)]` hint to the optimizer that this path is
/// rarely taken, so the closure's code is kept out of the caller's fast path.
#[cold]
#[inline(never)]
fn unlikely_call<T, F>(func: F) -> T
where
    F: FnOnce() -> T,
{
    func()
}
272
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex};
    use std::thread;

    use super::*;

    /// Increments `count` when dropped, so tests can observe how many times a
    /// value was destroyed.
    struct TrackDrop {
        count: Arc<Mutex<usize>>,
    }

    impl Drop for TrackDrop {
        fn drop(&mut self) {
            *self.count.lock().unwrap() += 1;
        }
    }

    /// `try_init` must return `None` instead of waiting while another thread is
    /// running the initializer.
    #[test]
    fn try_init_wont_block() {
        struct Shared {
            init_once: InitOnce<()>,
            // Passed once the spawned thread is inside its initializer.
            thread_barrier: std::sync::Barrier,
            // Keeps the initializer running until the main thread has checked.
            init_barrier: std::sync::Barrier,
        }

        let shared = Arc::new(Shared {
            init_once: InitOnce::new(),
            thread_barrier: std::sync::Barrier::new(2),
            init_barrier: std::sync::Barrier::new(2),
        });

        let shared2 = Arc::clone(&shared);

        let handle = std::thread::spawn(move || {
            assert!(shared2
                .init_once
                .try_init(|| {
                    shared2.thread_barrier.wait();
                    shared2.init_barrier.wait();
                })
                .is_some());
        });

        // Once this returns, the spawned thread is inside its initializer, so the
        // cell is deterministically INITIALIZING (no timing-based sleep needed).
        shared.thread_barrier.wait();
        assert!(shared.init_once.try_init(|| panic!()).is_none());

        shared.init_barrier.wait();
        handle.join().unwrap();
        assert!(shared.init_once.try_init(|| panic!()).is_some());
    }

    /// Async counterpart of `try_init_wont_block`: while one task is awaiting
    /// its initializer, both `try_init` and `try_init_async` return `None`.
    #[tokio::test]
    async fn try_init_async_wont_block() {
        struct Shared {
            init_once: InitOnce<()>,
            // Passed once the spawned task is inside its initializer.
            thread_barrier: tokio::sync::Barrier,
            // Keeps the initializer pending until the main task has checked.
            init_barrier: tokio::sync::Barrier,
        }

        let shared = Arc::new(Shared {
            init_once: InitOnce::new(),
            thread_barrier: tokio::sync::Barrier::new(2),
            init_barrier: tokio::sync::Barrier::new(2),
        });

        let shared2 = Arc::clone(&shared);

        let handle = tokio::spawn(async move {
            assert!(shared2
                .init_once
                .try_init_async(async {
                    shared2.thread_barrier.wait().await;
                    shared2.init_barrier.wait().await;
                })
                .await
                .is_some());
        });

        // Once this resolves, the spawned task is suspended inside its
        // initializer, so the cell is deterministically INITIALIZING.
        shared.thread_barrier.wait().await;
        assert!(shared.init_once.try_init(|| panic!()).is_none());
        assert!(shared
            .init_once
            .try_init_async(async { panic!() })
            .await
            .is_none());

        shared.init_barrier.wait().await;
        handle.await.unwrap();
        assert!(shared.init_once.try_init(|| panic!()).is_some());
        assert!(shared
            .init_once
            .try_init_async(async { panic!() })
            .await
            .is_some());
    }

    /// `init` through `&mut self` runs its closure only on the first call.
    #[test]
    fn init_mut_only_once() {
        let mut initialized = 0;
        let mut init_once = InitOnce::new();

        for _ in 0..10 {
            init_once.init(|| {
                initialized += 1;
            });
        }

        assert_eq!(initialized, 1);
    }

    /// `init_async` through `&mut self` awaits its future only on the first call.
    #[tokio::test]
    async fn init_mut_async_only_once() {
        let mut initialized = 0;
        let mut init_once = InitOnce::new();

        for _ in 0..10 {
            init_once
                .init_async(async {
                    initialized += 1;
                })
                .await;
        }

        assert_eq!(initialized, 1);
    }

    /// With many racing tasks, exactly one value is stored, it is not dropped
    /// while the cell is alive, and it is dropped exactly once with the cell.
    #[tokio::test]
    async fn dropped_once_if_init() {
        let mut init_once = Arc::new(InitOnce::new());
        let count = Arc::new(Mutex::new(0));

        assert_eq!(
            *Arc::get_mut(&mut init_once).unwrap().state.get_mut(),
            init_once_state::EMPTY
        );

        let tasks: Vec<_> = (0..10)
            .map(|_| {
                let init_once = Arc::clone(&init_once);
                let count = Arc::clone(&count);

                tokio::spawn(async move {
                    if let InitState::Polling(mut poller) = init_once.state() {
                        let TrackDrop {
                            count: current_count,
                        } = poller.init_async(future::ready(TrackDrop { count })).await;

                        assert_eq!(*current_count.lock().unwrap(), 0);
                    }
                })
            })
            .collect();

        for handle in tasks {
            handle.await.unwrap();
        }

        assert_eq!(*count.lock().unwrap(), 0);
        assert_eq!(
            *Arc::get_mut(&mut init_once).unwrap().state.get_mut(),
            init_once_state::INITIALIZED
        );

        drop(init_once);
        assert_eq!(*count.lock().unwrap(), 1);
    }

    /// Exactly one caller is granted `Polling`; if it never polls, the cell
    /// stays INITIALIZING for every later shared caller.
    #[test]
    fn never_poll_init() {
        let mut init_once = Arc::new(InitOnce::<()>::new());
        let count = Arc::new(Mutex::new(0));

        assert_eq!(
            *Arc::get_mut(&mut init_once).unwrap().state.get_mut(),
            init_once_state::EMPTY
        );

        assert_eq!(*count.lock().unwrap(), 0);

        let threads: Vec<_> = (0..10)
            .map(|_| {
                let init_once = Arc::clone(&init_once);
                let count = Arc::clone(&count);

                thread::spawn(move || {
                    if matches!(init_once.state(), InitState::Polling(_)) {
                        drop(TrackDrop { count });
                    }
                })
            })
            .collect();

        for handle in threads {
            handle.join().unwrap();
        }

        assert_eq!(*count.lock().unwrap(), 1);

        assert_eq!(
            *Arc::get_mut(&mut init_once).unwrap().state.get_mut(),
            init_once_state::INITIALIZING
        );

        for _ in 0..50 {
            assert!(matches!(init_once.state(), InitState::Initializing));
        }

        drop(init_once);
    }

    /// After a `PollInit` completes, further `poll_init` calls return the stored
    /// value without invoking their initializer.
    #[test]
    fn poll_init_only_once() {
        let mut once = InitOnce::new();
        let count = Arc::new(Mutex::new(0));

        assert_eq!(*count.lock().unwrap(), 0);

        if let InitState::Polling(mut poller) = once.state() {
            for i in 0..10 {
                _ = poller.poll_init(|| {
                    if i == 0 {
                        Poll::Ready((
                            420,
                            TrackDrop {
                                count: Arc::clone(&count),
                            },
                        ))
                    } else {
                        unreachable!()
                    }
                });
            }
        }

        let value = once.init(|| unreachable!());
        assert_eq!(value.0, 420);

        assert_eq!(*count.lock().unwrap(), 0);
        drop(once);
        assert_eq!(*count.lock().unwrap(), 1);
    }

    /// Dropping an `init_async` future before its initializer finishes must
    /// leave the cell uninitialized.
    #[tokio::test]
    async fn init_async_drop_future() {
        let mut once = InitOnce::new();
        let mut completed = false;

        {
            let future = once.init_async(async {
                tokio::task::yield_now().await;
                completed = true;
                420
            });
            let mut pinned_future = core::pin::pin!(future);

            // Poll exactly once (the initializer yields), then drop the future.
            std::future::poll_fn(|cx| match pinned_future.as_mut().poll(cx) {
                Poll::Ready(_) => unreachable!(),
                Poll::Pending => Poll::Ready(()),
            })
            .await;
        }

        assert!(!completed, "future was dropped before completing");
        assert_ne!(
            *once.state.get_mut(),
            init_once_state::INITIALIZED,
            "cell should not have been initialized",
        );
    }
}