1use std::{sync::Arc, sync::atomic::AtomicUsize, sync::atomic::Ordering};
2
3static mut CBS: Option<Arc<dyn CallbacksApi>> = None;
6
7static STATE: AtomicUsize = AtomicUsize::new(0);
8
9const UNINITIALIZED: usize = 0;
13const INITIALIZING: usize = 1;
14const INITIALIZED: usize = 2;
15
/// Type-erased interface over the four user-supplied hooks. It lets them be
/// stored behind a single `Arc<dyn CallbacksApi>`.
trait CallbacksApi {
    /// Produces the per-task pointer. `Data::load` yields `None` when this does.
    fn before(&self) -> Option<*const ()>;
    /// Receives the pointer from `before`. The value returned here is later
    /// passed to `exit`.
    fn enter(&self, data: *const ()) -> *const ();
    /// Receives the value previously returned by `enter`.
    fn exit(&self, entered: *const ());
    /// Receives the pointer from `before` when the owning `Data` is dropped.
    fn after(&self, data: *const ());
}
22
/// Concrete holder for the four hook closures. It is erased to
/// `dyn CallbacksApi` before being stored globally.
// Every field shares the `f_` prefix, which clippy's `struct_field_names` lint
// would otherwise flag.
#[allow(clippy::struct_field_names)]
struct Callbacks<A, B, C, D> {
    /// Closure backing `CallbacksApi::before`.
    f_before: A,
    /// Closure backing `CallbacksApi::enter`.
    f_enter: B,
    /// Closure backing `CallbacksApi::exit`.
    f_exit: C,
    /// Closure backing `CallbacksApi::after`.
    f_after: D,
}
30
31impl<A, B, C, D> CallbacksApi for Callbacks<A, B, C, D>
32where
33 A: Fn() -> Option<*const ()> + 'static,
34 B: Fn(*const ()) -> *const () + 'static,
35 C: Fn(*const ()) + 'static,
36 D: Fn(*const ()) + 'static,
37{
38 fn before(&self) -> Option<*const ()> {
39 (self.f_before)()
40 }
41 fn enter(&self, d: *const ()) -> *const () {
42 (self.f_enter)(d)
43 }
44 fn exit(&self, d: *const ()) {
45 (self.f_exit)(d);
46 }
47 fn after(&self, d: *const ()) {
48 (self.f_after)(d);
49 }
50}
51
52pub(crate) struct Data {
53 cb: &'static dyn CallbacksApi,
54 ptr: *const (),
55}
56
57impl Data {
58 #[allow(clippy::if_not_else)]
59 pub(crate) fn load() -> Option<Data> {
60 let cb = if STATE.load(Ordering::Acquire) != INITIALIZED {
64 None
65 } else {
66 #[allow(static_mut_refs)]
67 unsafe {
68 Some(CBS.as_ref().map(AsRef::as_ref).unwrap())
69 }
70 };
71
72 if let Some(cb) = cb
73 && let Some(ptr) = cb.before()
74 {
75 return Some(Data { cb, ptr });
76 }
77 None
78 }
79
80 pub(crate) fn run<F, R>(&mut self, f: F) -> R
81 where
82 F: FnOnce() -> R,
83 {
84 let ptr = self.cb.enter(self.ptr);
85 let result = f();
86 self.cb.exit(ptr);
87 result
88 }
89}
90
91impl Drop for Data {
92 fn drop(&mut self) {
93 self.cb.after(self.ptr);
94 }
95}
96
97pub unsafe fn task_callbacks<FBefore, FEnter, FExit, FAfter>(
108 f_before: FBefore,
109 f_enter: FEnter,
110 f_exit: FExit,
111 f_after: FAfter,
112) where
113 FBefore: Fn() -> Option<*const ()> + 'static + Sync,
114 FEnter: Fn(*const ()) -> *const () + 'static + Sync,
115 FExit: Fn(*const ()) + 'static + Sync,
116 FAfter: Fn(*const ()) + 'static + Sync,
117{
118 let new = Arc::new(Callbacks {
119 f_before,
120 f_enter,
121 f_exit,
122 f_after,
123 });
124 let _ = set_cbs(new);
125}
126
127pub unsafe fn task_opt_callbacks<FBefore, FEnter, FExit, FAfter>(
136 f_before: FBefore,
137 f_enter: FEnter,
138 f_exit: FExit,
139 f_after: FAfter,
140) -> bool
141where
142 FBefore: Fn() -> Option<*const ()> + Sync + 'static,
143 FEnter: Fn(*const ()) -> *const () + Sync + 'static,
144 FExit: Fn(*const ()) + Sync + 'static,
145 FAfter: Fn(*const ()) + Sync + 'static,
146{
147 let new = Arc::new(Callbacks {
148 f_before,
149 f_enter,
150 f_exit,
151 f_after,
152 });
153 set_cbs(new).is_ok()
154}
155
156fn set_cbs(cbs: Arc<dyn CallbacksApi>) -> Result<(), ()> {
157 match STATE.compare_exchange(
158 UNINITIALIZED,
159 INITIALIZING,
160 Ordering::Acquire,
161 Ordering::Relaxed,
162 ) {
163 Ok(UNINITIALIZED) => {
164 unsafe {
165 CBS = Some(cbs);
166 }
167 STATE.store(INITIALIZED, Ordering::Release);
168 Ok(())
169 }
170 Err(INITIALIZING) => {
171 while STATE.load(Ordering::Relaxed) == INITIALIZING {
172 std::hint::spin_loop();
173 }
174 Err(())
175 }
176 _ => Err(()),
177 }
178}