1use crate::sem::{Semaphore, Sem};
2
3use core::ptr;
4use core::sync::atomic::{AtomicU8, AtomicU32, AtomicPtr, Ordering};
5
// Lifecycle states stored in `Once::state`.

/// The callback has not been started yet.
const INCOMPLETE: u8 = 0;
/// A thread has claimed the `Once` and is creating the wait semaphore or running the callback.
const RUNNING: u8 = 1;
/// The callback has finished (normally or by panicking).
const COMPLETE: u8 = 2;
/// The wait semaphore could not be created; this state is terminal.
const FAIL: u8 = 4;

/// Panic message used whenever the `Once` is in, or enters, the `FAIL` state.
const FAIL_MSG: &str = "Unable to initialize semaphore";
12
13struct StateGuard<'a> {
14 state: &'a AtomicU8,
15}
16
17impl<'a> Drop for StateGuard<'a> {
18 fn drop(&mut self) {
19 self.state.store(COMPLETE, Ordering::Release);
20 }
21}
22
23struct SemGuard {
24 sem: Sem,
25 waiting: AtomicU32,
26}
27
28impl SemGuard {
29 fn wait(&self) {
30 self.waiting.fetch_add(1, Ordering::Release);
31 self.sem.wait();
32 }
33}
34
35impl Drop for SemGuard {
36 fn drop(&mut self) {
37 for _ in 0..self.waiting.load(Ordering::Acquire) {
38 self.sem.signal();
39 }
40 }
41}
42
43pub struct Once {
45 sem: AtomicPtr<SemGuard>,
46 state: AtomicU8,
47}
48
49impl Once {
50 pub const fn new() -> Self {
52 Self {
53 sem: AtomicPtr::new(ptr::null_mut()),
54 state: AtomicU8::new(INCOMPLETE),
55 }
56 }
57
58 pub fn call_once<F: FnOnce()>(&self, cb: F) {
62 if self.is_completed() {
63 return;
64 }
65
66 let mut cb = Some(cb);
67 self.call_inner(move || match cb.take() {
68 Some(cb) => cb(),
69 None => unreach!()
70 });
71 }
72
73 #[inline]
74 pub fn is_completed(&self) -> bool {
76 self.state.load(Ordering::Acquire) == COMPLETE
77 }
78
79 #[cold]
80 fn call_inner_fail(&self) -> ! {
81 self.state.store(FAIL, Ordering::Acquire);
82 panic!(FAIL_MSG)
83 }
84
85 #[cold]
86 fn call_inner<F: FnMut()>(&self, mut cb: F) {
87 loop {
88 match self.state.load(Ordering::Acquire) {
89 COMPLETE => break,
90 FAIL => panic!(FAIL_MSG),
91 INCOMPLETE => {
92 if INCOMPLETE != self.state.compare_and_swap(INCOMPLETE, RUNNING, Ordering::Acquire) {
93 continue;
94 }
95
96 let sem = match Sem::new(0) {
97 Some(sem) => sem,
98 None => self.call_inner_fail(),
99 };
100
101 let mut sem_guard = SemGuard {
102 sem,
103 waiting: AtomicU32::new(0),
104 };
105 self.sem.store(&mut sem_guard as *mut _, Ordering::Release);
106
107 let _state_guard = StateGuard {
109 state: &self.state
110 };
111
112 cb();
113
114 },
115 state => {
116 debug_assert_eq!(state, RUNNING);
117
118 let mut sem = self.sem.load(Ordering::Acquire);
119 while sem.is_null() {
120 if self.state.load(Ordering::Acquire) == FAIL {
121 panic!(FAIL_MSG);
122 }
123
124 sem = self.sem.load(Ordering::Acquire);
125 core::sync::atomic::spin_loop_hint();
126 }
127
128 if self.state.load(Ordering::Acquire) != RUNNING {
129 unsafe {
130 (*sem).wait()
131 }
132 }
133 },
134 }
135 }
136 }
137}