use std::{sync::Arc, sync::atomic::AtomicUsize, sync::atomic::Ordering};
static mut CBS: Option<Arc<dyn CallbacksApi>> = None;
static STATE: AtomicUsize = AtomicUsize::new(0);
const UNINITIALIZED: usize = 0;
const INITIALIZING: usize = 1;
const INITIALIZED: usize = 2;
trait CallbacksApi {
fn before(&self) -> Option<*const ()>;
fn enter(&self, _: *const ()) -> *const ();
fn exit(&self, _: *const ());
fn after(&self, _: *const ());
}
#[allow(clippy::struct_field_names)]
struct Callbacks<A, B, C, D> {
f_before: A,
f_enter: B,
f_exit: C,
f_after: D,
}
impl<A, B, C, D> CallbacksApi for Callbacks<A, B, C, D>
where
A: Fn() -> Option<*const ()> + 'static,
B: Fn(*const ()) -> *const () + 'static,
C: Fn(*const ()) + 'static,
D: Fn(*const ()) + 'static,
{
fn before(&self) -> Option<*const ()> {
(self.f_before)()
}
fn enter(&self, d: *const ()) -> *const () {
(self.f_enter)(d)
}
fn exit(&self, d: *const ()) {
(self.f_exit)(d);
}
fn after(&self, d: *const ()) {
(self.f_after)(d);
}
}
pub(crate) struct Data {
cb: &'static dyn CallbacksApi,
ptr: *const (),
}
impl Data {
#[allow(clippy::if_not_else)]
pub(crate) fn load() -> Option<Data> {
let cb = if STATE.load(Ordering::Relaxed) != INITIALIZED {
None
} else {
#[allow(static_mut_refs)]
unsafe {
Some(CBS.as_ref().map(AsRef::as_ref).unwrap())
}
};
if let Some(cb) = cb
&& let Some(ptr) = cb.before()
{
return Some(Data { cb, ptr });
}
None
}
pub(crate) fn run<F, R>(&mut self, f: F) -> R
where
F: FnOnce() -> R,
{
let ptr = self.cb.enter(self.ptr);
let result = f();
self.cb.exit(ptr);
result
}
}
impl Drop for Data {
fn drop(&mut self) {
self.cb.after(self.ptr);
}
}
pub unsafe fn task_callbacks<FBefore, FEnter, FExit, FAfter>(
f_before: FBefore,
f_enter: FEnter,
f_exit: FExit,
f_after: FAfter,
) where
FBefore: Fn() -> Option<*const ()> + 'static + Sync,
FEnter: Fn(*const ()) -> *const () + 'static + Sync,
FExit: Fn(*const ()) + 'static + Sync,
FAfter: Fn(*const ()) + 'static + Sync,
{
let new = Arc::new(Callbacks {
f_before,
f_enter,
f_exit,
f_after,
});
let _ = set_cbs(new);
}
pub unsafe fn task_opt_callbacks<FBefore, FEnter, FExit, FAfter>(
f_before: FBefore,
f_enter: FEnter,
f_exit: FExit,
f_after: FAfter,
) -> bool
where
FBefore: Fn() -> Option<*const ()> + Sync + 'static,
FEnter: Fn(*const ()) -> *const () + Sync + 'static,
FExit: Fn(*const ()) + Sync + 'static,
FAfter: Fn(*const ()) + Sync + 'static,
{
let new = Arc::new(Callbacks {
f_before,
f_enter,
f_exit,
f_after,
});
set_cbs(new).is_ok()
}
fn set_cbs(cbs: Arc<dyn CallbacksApi>) -> Result<(), ()> {
match STATE.compare_exchange(
UNINITIALIZED,
INITIALIZING,
Ordering::Acquire,
Ordering::Relaxed,
) {
Ok(UNINITIALIZED) => {
unsafe {
CBS = Some(cbs);
}
STATE.store(INITIALIZED, Ordering::Release);
Ok(())
}
Err(INITIALIZING) => {
while STATE.load(Ordering::Relaxed) == INITIALIZING {
std::hint::spin_loop();
}
Err(())
}
_ => Err(()),
}
}