// ntex_rt/task.rs

1use std::{sync::Arc, sync::atomic::AtomicUsize, sync::atomic::Ordering};
2
3// The Callbacks static holds a pointer to the global logger. It is protected by
4// the STATE static which determines whether LOGGER has been initialized yet.
5static mut CBS: Option<Arc<dyn CallbacksApi>> = None;
6
7static STATE: AtomicUsize = AtomicUsize::new(0);
8
9// There are three different states that we care about: the logger's
10// uninitialized, the logger's initializing (set_logger's been called but
11// LOGGER hasn't actually been set yet), or the logger's active.
12const UNINITIALIZED: usize = 0;
13const INITIALIZING: usize = 1;
14const INITIALIZED: usize = 2;
15
/// Object-safe view of the four user-supplied task callbacks, so that a concrete
/// `Callbacks<..>` can be stored behind a single `Arc<dyn CallbacksApi>`.
trait CallbacksApi {
    /// Produces the opaque per-task pointer, or `None` when no per-task data is wanted.
    fn before(&self) -> Option<*const ()>;
    /// Receives the pointer from `before`; the pointer it returns is later handed to `exit`.
    fn enter(&self, ptr: *const ()) -> *const ();
    /// Receives the pointer previously returned by `enter`.
    fn exit(&self, ptr: *const ());
    /// Receives the pointer from `before` once it is no longer needed.
    fn after(&self, ptr: *const ());
}

// Every field intentionally shares the `f_` prefix, mirroring the parameter names of the
// public installer functions; that shared prefix is what `clippy::struct_field_names` flags.
#[allow(clippy::struct_field_names)]
struct Callbacks<A, B, C, D> {
    f_before: A,
    f_enter: B,
    f_exit: C,
    f_after: D,
}

impl<A, B, C, D> CallbacksApi for Callbacks<A, B, C, D>
where
    A: 'static + Fn() -> Option<*const ()>,
    B: 'static + Fn(*const ()) -> *const (),
    C: 'static + Fn(*const ()),
    D: 'static + Fn(*const ()),
{
    fn before(&self) -> Option<*const ()> {
        let before = &self.f_before;
        before()
    }

    fn enter(&self, ptr: *const ()) -> *const () {
        let enter = &self.f_enter;
        enter(ptr)
    }

    fn exit(&self, ptr: *const ()) {
        let exit = &self.f_exit;
        exit(ptr);
    }

    fn after(&self, ptr: *const ()) {
        let after = &self.f_after;
        after(ptr);
    }
}

52pub(crate) struct Data {
53    cb: &'static dyn CallbacksApi,
54    ptr: *const (),
55}
56
57impl Data {
58    #[allow(clippy::if_not_else)]
59    pub(crate) fn load() -> Option<Data> {
60        // Acquire memory ordering guarantees that current thread would see any
61        // memory writes that happened before store of the value
62        // into `STATE` with memory ordering `Release` or stronger.
63        let cb = if STATE.load(Ordering::Acquire) != INITIALIZED {
64            None
65        } else {
66            #[allow(static_mut_refs)]
67            unsafe {
68                Some(CBS.as_ref().map(AsRef::as_ref).unwrap())
69            }
70        };
71
72        if let Some(cb) = cb
73            && let Some(ptr) = cb.before()
74        {
75            return Some(Data { cb, ptr });
76        }
77        None
78    }
79
80    pub(crate) fn run<F, R>(&mut self, f: F) -> R
81    where
82        F: FnOnce() -> R,
83    {
84        let ptr = self.cb.enter(self.ptr);
85        let result = f();
86        self.cb.exit(ptr);
87        result
88    }
89}
90
91impl Drop for Data {
92    fn drop(&mut self) {
93        self.cb.after(self.ptr);
94    }
95}
96
97/// # Safety
98///
99/// The user must ensure that the pointer returned by `before` has a `'static` lifetime.
100/// This pointer will be owned by the spawned task for the duration of that task, and
101/// ownership will be returned to the user at the end of the task via `after`.
102/// The pointer remains opaque to the runtime.
103///
104/// # Panics
105///
106/// Panics if task callbacks have already been set.
107pub unsafe fn task_callbacks<FBefore, FEnter, FExit, FAfter>(
108    f_before: FBefore,
109    f_enter: FEnter,
110    f_exit: FExit,
111    f_after: FAfter,
112) where
113    FBefore: Fn() -> Option<*const ()> + 'static + Sync,
114    FEnter: Fn(*const ()) -> *const () + 'static + Sync,
115    FExit: Fn(*const ()) + 'static + Sync,
116    FAfter: Fn(*const ()) + 'static + Sync,
117{
118    let new = Arc::new(Callbacks {
119        f_before,
120        f_enter,
121        f_exit,
122        f_after,
123    });
124    let _ = set_cbs(new);
125}
126
127/// # Safety
128///
129/// The user must ensure that the pointer returned by `before` has a `'static` lifetime.
130/// This pointer will be owned by the spawned task for the duration of that task, and
131/// ownership will be returned to the user at the end of the task via `after`.
132/// The pointer remains opaque to the runtime.
133///
134/// Returns false if task callbacks have already been set.
135pub unsafe fn task_opt_callbacks<FBefore, FEnter, FExit, FAfter>(
136    f_before: FBefore,
137    f_enter: FEnter,
138    f_exit: FExit,
139    f_after: FAfter,
140) -> bool
141where
142    FBefore: Fn() -> Option<*const ()> + Sync + 'static,
143    FEnter: Fn(*const ()) -> *const () + Sync + 'static,
144    FExit: Fn(*const ()) + Sync + 'static,
145    FAfter: Fn(*const ()) + Sync + 'static,
146{
147    let new = Arc::new(Callbacks {
148        f_before,
149        f_enter,
150        f_exit,
151        f_after,
152    });
153    set_cbs(new).is_ok()
154}
155
156fn set_cbs(cbs: Arc<dyn CallbacksApi>) -> Result<(), ()> {
157    match STATE.compare_exchange(
158        UNINITIALIZED,
159        INITIALIZING,
160        Ordering::Acquire,
161        Ordering::Relaxed,
162    ) {
163        Ok(UNINITIALIZED) => {
164            unsafe {
165                CBS = Some(cbs);
166            }
167            STATE.store(INITIALIZED, Ordering::Release);
168            Ok(())
169        }
170        Err(INITIALIZING) => {
171            while STATE.load(Ordering::Relaxed) == INITIALIZING {
172                std::hint::spin_loop();
173            }
174            Err(())
175        }
176        _ => Err(()),
177    }
178}