// kayrx_karx/join_handle.rs

1use core::fmt;
2use core::future::Future;
3use core::marker::{PhantomData, Unpin};
4use core::pin::Pin;
5use core::ptr::NonNull;
6use core::sync::atomic::Ordering;
7use core::task::{Context, Poll, Waker};
8
9use crate::header::Header;
10use crate::state::*;
11
/// An awaitable handle to the result of a spawned task.
///
/// Awaiting a `JoinHandle` yields an `Option<R>`:
///
/// * `Some(result)` — the task ran to completion and produced `result` of type `R`.
/// * `None` — the task was canceled or panicked.
pub struct JoinHandle<R, T> {
    /// Type-erased pointer to the task allocation; the task's `Header` is read through it.
    pub(crate) raw_task: NonNull<()>,

    /// Zero-sized marker tying the handle to the output type `R` and the tag type `T`.
    pub(crate) _marker: PhantomData<(R, T)>,
}
25
// SAFETY: a `JoinHandle` sent to another thread may move the task's output out (`poll`) or
// drop it (`Drop`) on that thread, so `R: Send` is required.
// NOTE(review): `T` is unbounded here, yet `tag()` hands out `&T` and `Drop` may destroy the
// task (presumably including its tag) on whichever thread drops the handle. Soundness appears
// to rely on every constructor of `JoinHandle` requiring `T: Send + Sync` — confirm, or add
// those bounds to these impls.
unsafe impl<R: Send, T> Send for JoinHandle<R, T> {}
// SAFETY: nothing reachable through `&JoinHandle` (`cancel`, `tag`, `waker`, `Debug`) exposes
// a value of type `R`.
// NOTE(review): through a shared `&JoinHandle`, `tag()` lets several threads hold `&T` at
// once, which is only sound when `T: Sync` — presumably enforced where tasks are created;
// confirm.
unsafe impl<R, T> Sync for JoinHandle<R, T> {}

// The handle is only a raw pointer plus a marker, and no method depends on the `JoinHandle`
// value itself staying at a fixed address, so it is `Unpin` regardless of `R` and `T`.
impl<R, T> Unpin for JoinHandle<R, T> {}
30
31impl<R, T> JoinHandle<R, T> {
32    /// Cancels the task.
33    ///
34    /// If the task has already completed, calling this method will have no effect.
35    ///
36    /// When a task is canceled, its future will not be polled again.
37    pub fn cancel(&self) {
38        let ptr = self.raw_task.as_ptr();
39        let header = ptr as *const Header;
40
41        unsafe {
42            let mut state = (*header).state.load(Ordering::Acquire);
43
44            loop {
45                // If the task has been completed or closed, it can't be canceled.
46                if state & (COMPLETED | CLOSED) != 0 {
47                    break;
48                }
49
50                // If the task is not scheduled nor running, we'll need to schedule it.
51                let new = if state & (SCHEDULED | RUNNING) == 0 {
52                    (state | SCHEDULED | CLOSED) + REFERENCE
53                } else {
54                    state | CLOSED
55                };
56
57                // Mark the task as closed.
58                match (*header).state.compare_exchange_weak(
59                    state,
60                    new,
61                    Ordering::AcqRel,
62                    Ordering::Acquire,
63                ) {
64                    Ok(_) => {
65                        // If the task is not scheduled nor running, schedule it one more time so
66                        // that its future gets dropped by the executor.
67                        if state & (SCHEDULED | RUNNING) == 0 {
68                            ((*header).vtable.schedule)(ptr);
69                        }
70
71                        // Notify the awaiter that the task has been closed.
72                        if state & AWAITER != 0 {
73                            (*header).notify(None);
74                        }
75
76                        break;
77                    }
78                    Err(s) => state = s,
79                }
80            }
81        }
82    }
83
84    /// Returns a reference to the tag stored inside the task.
85    pub fn tag(&self) -> &T {
86        let offset = Header::offset_tag::<T>();
87        let ptr = self.raw_task.as_ptr();
88
89        unsafe {
90            let raw = (ptr as *mut u8).add(offset) as *const T;
91            &*raw
92        }
93    }
94
95    /// Returns a waker associated with the task.
96    pub fn waker(&self) -> Waker {
97        let ptr = self.raw_task.as_ptr();
98        let header = ptr as *const Header;
99
100        unsafe {
101            let raw_waker = ((*header).vtable.clone_waker)(ptr);
102            Waker::from_raw(raw_waker)
103        }
104    }
105}
106
impl<R, T> Drop for JoinHandle<R, T> {
    /// Releases this handle's claim on the task by clearing the `HANDLE` bit.
    ///
    /// If the task has completed but its output was never taken (not yet `CLOSED`), the
    /// output is moved out and dropped here. If no references to the task remain, the task is
    /// either scheduled one last time (when not yet closed, so the executor drops its future)
    /// or destroyed outright (when already closed).
    fn drop(&mut self) {
        let ptr = self.raw_task.as_ptr();
        let header = ptr as *const Header;

        // A place where the output will be stored in case it needs to be dropped.
        let mut output = None;

        unsafe {
            // Optimistically assume the `JoinHandle` is being dropped just after creating the
            // task. This is a common case so if the handle is not used, the overhead of it is only
            // one compare-exchange operation.
            //
            // `compare_exchange_weak` may fail spuriously; that is harmless because the general
            // loop below performs the same transition.
            if let Err(mut state) = (*header).state.compare_exchange_weak(
                SCHEDULED | HANDLE | REFERENCE,
                SCHEDULED | REFERENCE,
                Ordering::AcqRel,
                Ordering::Acquire,
            ) {
                loop {
                    // If the task has been completed but not yet closed, that means its output
                    // must be dropped.
                    if state & COMPLETED != 0 && state & CLOSED == 0 {
                        // Mark the task as closed in order to grab its output.
                        match (*header).state.compare_exchange_weak(
                            state,
                            state | CLOSED,
                            Ordering::AcqRel,
                            Ordering::Acquire,
                        ) {
                            Ok(_) => {
                                // Read the output. Setting CLOSED above makes this handle the
                                // party that takes it (in this file, `poll` also only reads it
                                // after winning the same COMPLETED -> CLOSED transition).
                                output =
                                    Some((((*header).vtable.get_output)(ptr) as *mut R).read());

                                // Update the state variable because we're continuing the loop.
                                state |= CLOSED;
                            }
                            Err(s) => state = s,
                        }
                    } else {
                        // If this is the last reference to the task and it's not closed, then
                        // close it and schedule one more time so that its future gets dropped by
                        // the executor.
                        //
                        // `!(REFERENCE - 1)` masks the bits at or above `REFERENCE`; the code
                        // treats them as the reference count (assumes `REFERENCE` is a power of
                        // two defined in `crate::state` — confirm).
                        let new = if state & (!(REFERENCE - 1) | CLOSED) == 0 {
                            SCHEDULED | CLOSED | REFERENCE
                        } else {
                            state & !HANDLE
                        };

                        // Unset the handle flag.
                        match (*header).state.compare_exchange_weak(
                            state,
                            new,
                            Ordering::AcqRel,
                            Ordering::Acquire,
                        ) {
                            Ok(_) => {
                                // If this is the last reference to the task, we need to either
                                // schedule dropping its future or destroy it.
                                if state & !(REFERENCE - 1) == 0 {
                                    if state & CLOSED == 0 {
                                        ((*header).vtable.schedule)(ptr);
                                    } else {
                                        ((*header).vtable.destroy)(ptr);
                                    }
                                }

                                break;
                            }
                            // Lost a race (or failed spuriously): re-evaluate with the
                            // observed state.
                            Err(s) => state = s,
                        }
                    }
                }
            }
        }

        // Drop the output if it was taken out of the task. This happens only after the state
        // transitions above have finished, presumably so that `R`'s destructor cannot interrupt
        // them.
        drop(output);
    }
}
187
impl<R, T> Future for JoinHandle<R, T> {
    /// `Some(output)` when this handle took the completed task's output; `None` when the task
    /// was closed (canceled, or panicked) without this handle taking an output.
    type Output = Option<R>;

    /// Polls for the task's output.
    ///
    /// * `Poll::Ready(Some(output))`: the task has completed. This call sets `CLOSED` and moves
    ///   the output out of the task.
    /// * `Poll::Ready(None)`: the task is closed and is neither scheduled nor running, so its
    ///   future has been dropped.
    /// * `Poll::Pending`: `cx.waker()` has been registered with the task header. Either the task
    ///   has not completed, or it is closed but still scheduled or running.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let ptr = self.raw_task.as_ptr();
        let header = ptr as *const Header;

        unsafe {
            let mut state = (*header).state.load(Ordering::Acquire);

            loop {
                // If the task has been closed, notify the awaiter and return `None`.
                if state & CLOSED != 0 {
                    // If the task is scheduled or running, we need to wait until its future is
                    // dropped.
                    if state & (SCHEDULED | RUNNING) != 0 {
                        // Replace the waker with one associated with the current task.
                        (*header).register(cx.waker());

                        // Reload the state after registering. It is possible changes occurred just
                        // before registration so we need to check for that.
                        state = (*header).state.load(Ordering::Acquire);

                        // If the task is still scheduled or running, we need to wait because its
                        // future is not dropped yet.
                        if state & (SCHEDULED | RUNNING) != 0 {
                            return Poll::Pending;
                        }
                    }

                    // Even though the awaiter is most likely the current task, it could also be
                    // another task.
                    //
                    // NOTE(review): the current waker is passed in, presumably so `notify` can skip
                    // waking the current task redundantly. Confirm in `Header::notify`.
                    (*header).notify(Some(cx.waker()));
                    return Poll::Ready(None);
                }

                // If the task is not completed, register the current task.
                if state & COMPLETED == 0 {
                    // Replace the waker with one associated with the current task.
                    (*header).register(cx.waker());

                    // Reload the state after registering. It is possible that the task became
                    // completed or closed just before registration so we need to check for that.
                    state = (*header).state.load(Ordering::Acquire);

                    // If the task has been closed, restart.
                    if state & CLOSED != 0 {
                        continue;
                    }

                    // If the task is still not completed, we're blocked on it.
                    if state & COMPLETED == 0 {
                        return Poll::Pending;
                    }
                }

                // Since the task is now completed, mark it as closed in order to grab its output.
                // On failure the loop re-evaluates everything with the observed state.
                match (*header).state.compare_exchange(
                    state,
                    state | CLOSED,
                    Ordering::AcqRel,
                    Ordering::Acquire,
                ) {
                    Ok(_) => {
                        // Notify the awaiter. Even though the awaiter is most likely the current
                        // task, it could also be another task.
                        if state & AWAITER != 0 {
                            (*header).notify(Some(cx.waker()));
                        }

                        // Take the output from the task. Winning the COMPLETED -> CLOSED
                        // transition above is what makes this handle its sole reader (`Drop`
                        // only reads it when CLOSED is still clear).
                        let output = ((*header).vtable.get_output)(ptr) as *mut R;
                        return Poll::Ready(Some(output.read()));
                    }
                    Err(s) => state = s,
                }
            }
        }
    }
}
268
269impl<R, T> fmt::Debug for JoinHandle<R, T> {
270    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
271        let ptr = self.raw_task.as_ptr();
272        let header = ptr as *const Header;
273
274        f.debug_struct("JoinHandle")
275            .field("header", unsafe { &(*header) })
276            .finish()
277    }
278}