Skip to main content

ntex_rt/
task.rs

1use std::{sync::Arc, sync::atomic::AtomicUsize, sync::atomic::Ordering};
2
3// The Callbacks static holds a pointer to the global callbacks. It is protected by
4// the STATE static which determines whether `CBS` has been initialized yet.
5static mut CBS: Option<Arc<dyn CallbacksApi>> = None;
6
7static STATE: AtomicUsize = AtomicUsize::new(0);
8
// Values stored in `STATE`. It only ever moves forward through them:
// `UNINITIALIZED` before any successful `set_cbs` call, `INITIALIZING` while
// `set_cbs` is writing `CBS`, and `INITIALIZED` once `CBS` holds the callbacks
// and may be read.
const UNINITIALIZED: usize = 0;
const INITIALIZING: usize = 1;
const INITIALIZED: usize = 2;
15
/// Type-erased interface over the four user-supplied task callbacks, so that
/// they can be stored behind a single `dyn` pointer in `CBS`.
trait CallbacksApi {
    /// Called by `Data::load`; returning `None` means no `Data` is created.
    fn before(&self) -> Option<*const ()>;
    /// Called at the start of `Data::run` with the pointer from `before`;
    /// the returned pointer is later handed to `exit`.
    fn enter(&self, data: *const ()) -> *const ();
    /// Called at the end of `Data::run` with the pointer returned by `enter`.
    fn exit(&self, entered: *const ());
    /// Called when a `Data` is dropped, with the pointer from `before`.
    fn after(&self, data: *const ());
}
22
/// Concrete holder for the closures passed to `task_callbacks` /
/// `task_opt_callbacks`; it is erased to `dyn CallbacksApi` when installed.
// Every field deliberately shares the `f_` prefix to mirror the public
// parameter names, which is what `struct_field_names` would complain about.
#[allow(clippy::struct_field_names)]
struct Callbacks<A, B, C, D> {
    /// Backs `CallbacksApi::before`.
    f_before: A,
    /// Backs `CallbacksApi::enter`.
    f_enter: B,
    /// Backs `CallbacksApi::exit`.
    f_exit: C,
    /// Backs `CallbacksApi::after`.
    f_after: D,
}
30
31impl<A, B, C, D> CallbacksApi for Callbacks<A, B, C, D>
32where
33    A: Fn() -> Option<*const ()> + 'static,
34    B: Fn(*const ()) -> *const () + 'static,
35    C: Fn(*const ()) + 'static,
36    D: Fn(*const ()) + 'static,
37{
38    fn before(&self) -> Option<*const ()> {
39        (self.f_before)()
40    }
41    fn enter(&self, d: *const ()) -> *const () {
42        (self.f_enter)(d)
43    }
44    fn exit(&self, d: *const ()) {
45        (self.f_exit)(d);
46    }
47    fn after(&self, d: *const ()) {
48        (self.f_after)(d);
49    }
50}
51
52pub(crate) struct Data {
53    cb: &'static dyn CallbacksApi,
54    ptr: *const (),
55}
56
57impl Data {
58    #[allow(clippy::if_not_else)]
59    pub(crate) fn load() -> Option<Data> {
60        // We only care about validity of state nothing else
61        let cb = if STATE.load(Ordering::Relaxed) != INITIALIZED {
62            None
63        } else {
64            #[allow(static_mut_refs)]
65            unsafe {
66                Some(CBS.as_ref().map(AsRef::as_ref).unwrap())
67            }
68        };
69
70        if let Some(cb) = cb
71            && let Some(ptr) = cb.before()
72        {
73            return Some(Data { cb, ptr });
74        }
75        None
76    }
77
78    pub(crate) fn run<F, R>(&mut self, f: F) -> R
79    where
80        F: FnOnce() -> R,
81    {
82        let ptr = self.cb.enter(self.ptr);
83        let result = f();
84        self.cb.exit(ptr);
85        result
86    }
87}
88
89impl Drop for Data {
90    fn drop(&mut self) {
91        self.cb.after(self.ptr);
92    }
93}
94
95/// # Safety
96///
97/// The user must ensure that the pointer returned by `before` has a `'static` lifetime.
98/// This pointer will be owned by the spawned task for the duration of that task, and
99/// ownership will be returned to the user at the end of the task via `after`.
100/// The pointer remains opaque to the runtime.
101///
102/// # Panics
103///
104/// Panics if task callbacks have already been set.
105pub unsafe fn task_callbacks<FBefore, FEnter, FExit, FAfter>(
106    f_before: FBefore,
107    f_enter: FEnter,
108    f_exit: FExit,
109    f_after: FAfter,
110) where
111    FBefore: Fn() -> Option<*const ()> + 'static + Sync,
112    FEnter: Fn(*const ()) -> *const () + 'static + Sync,
113    FExit: Fn(*const ()) + 'static + Sync,
114    FAfter: Fn(*const ()) + 'static + Sync,
115{
116    let new = Arc::new(Callbacks {
117        f_before,
118        f_enter,
119        f_exit,
120        f_after,
121    });
122    let _ = set_cbs(new);
123}
124
125/// # Safety
126///
127/// The user must ensure that the pointer returned by `before` has a `'static` lifetime.
128/// This pointer will be owned by the spawned task for the duration of that task, and
129/// ownership will be returned to the user at the end of the task via `after`.
130/// The pointer remains opaque to the runtime.
131///
132/// Returns false if task callbacks have already been set.
133pub unsafe fn task_opt_callbacks<FBefore, FEnter, FExit, FAfter>(
134    f_before: FBefore,
135    f_enter: FEnter,
136    f_exit: FExit,
137    f_after: FAfter,
138) -> bool
139where
140    FBefore: Fn() -> Option<*const ()> + Sync + 'static,
141    FEnter: Fn(*const ()) -> *const () + Sync + 'static,
142    FExit: Fn(*const ()) + Sync + 'static,
143    FAfter: Fn(*const ()) + Sync + 'static,
144{
145    let new = Arc::new(Callbacks {
146        f_before,
147        f_enter,
148        f_exit,
149        f_after,
150    });
151    set_cbs(new).is_ok()
152}
153
154fn set_cbs(cbs: Arc<dyn CallbacksApi>) -> Result<(), ()> {
155    match STATE.compare_exchange(
156        UNINITIALIZED,
157        INITIALIZING,
158        Ordering::Acquire,
159        Ordering::Relaxed,
160    ) {
161        Ok(UNINITIALIZED) => {
162            unsafe {
163                CBS = Some(cbs);
164            }
165            STATE.store(INITIALIZED, Ordering::Release);
166            Ok(())
167        }
168        Err(INITIALIZING) => {
169            while STATE.load(Ordering::Relaxed) == INITIALIZING {
170                std::hint::spin_loop();
171            }
172            Err(())
173        }
174        _ => Err(()),
175    }
176}