kayrx_karx/join_handle.rs
use core::fmt;
use core::future::Future;
use core::marker::{PhantomData, Unpin};
use core::pin::Pin;
use core::ptr::NonNull;
use core::sync::atomic::Ordering;
use core::task::{Context, Poll, Waker};

use crate::header::Header;
use crate::state::*;
11
/// A handle that awaits the result of a task.
///
/// This type is a future that resolves to an `Option<R>` where:
///
/// * `None` indicates the task has panicked or was canceled.
/// * `Some(result)` indicates the task has completed with `result` of type `R`.
///
/// `R` is the type of the task's output and `T` is the type of the tag stored inside the task.
pub struct JoinHandle<R, T> {
    /// A raw task pointer.
    ///
    /// The methods of this type reinterpret it as a pointer to the task's `Header`.
    pub(crate) raw_task: NonNull<()>,

    /// A marker capturing generic types `R` and `T`.
    ///
    /// Neither an `R` nor a `T` is stored in the handle itself; both are reached through
    /// `raw_task`.
    pub(crate) _marker: PhantomData<(R, T)>,
}
25
26unsafe impl<R: Send, T> Send for JoinHandle<R, T> {}
27unsafe impl<R, T> Sync for JoinHandle<R, T> {}
28
29impl<R, T> Unpin for JoinHandle<R, T> {}
30
31impl<R, T> JoinHandle<R, T> {
32 /// Cancels the task.
33 ///
34 /// If the task has already completed, calling this method will have no effect.
35 ///
36 /// When a task is canceled, its future will not be polled again.
37 pub fn cancel(&self) {
38 let ptr = self.raw_task.as_ptr();
39 let header = ptr as *const Header;
40
41 unsafe {
42 let mut state = (*header).state.load(Ordering::Acquire);
43
44 loop {
45 // If the task has been completed or closed, it can't be canceled.
46 if state & (COMPLETED | CLOSED) != 0 {
47 break;
48 }
49
50 // If the task is not scheduled nor running, we'll need to schedule it.
51 let new = if state & (SCHEDULED | RUNNING) == 0 {
52 (state | SCHEDULED | CLOSED) + REFERENCE
53 } else {
54 state | CLOSED
55 };
56
57 // Mark the task as closed.
58 match (*header).state.compare_exchange_weak(
59 state,
60 new,
61 Ordering::AcqRel,
62 Ordering::Acquire,
63 ) {
64 Ok(_) => {
65 // If the task is not scheduled nor running, schedule it one more time so
66 // that its future gets dropped by the executor.
67 if state & (SCHEDULED | RUNNING) == 0 {
68 ((*header).vtable.schedule)(ptr);
69 }
70
71 // Notify the awaiter that the task has been closed.
72 if state & AWAITER != 0 {
73 (*header).notify(None);
74 }
75
76 break;
77 }
78 Err(s) => state = s,
79 }
80 }
81 }
82 }
83
84 /// Returns a reference to the tag stored inside the task.
85 pub fn tag(&self) -> &T {
86 let offset = Header::offset_tag::<T>();
87 let ptr = self.raw_task.as_ptr();
88
89 unsafe {
90 let raw = (ptr as *mut u8).add(offset) as *const T;
91 &*raw
92 }
93 }
94
95 /// Returns a waker associated with the task.
96 pub fn waker(&self) -> Waker {
97 let ptr = self.raw_task.as_ptr();
98 let header = ptr as *const Header;
99
100 unsafe {
101 let raw_waker = ((*header).vtable.clone_waker)(ptr);
102 Waker::from_raw(raw_waker)
103 }
104 }
105}
106
107impl<R, T> Drop for JoinHandle<R, T> {
108 fn drop(&mut self) {
109 let ptr = self.raw_task.as_ptr();
110 let header = ptr as *const Header;
111
112 // A place where the output will be stored in case it needs to be dropped.
113 let mut output = None;
114
115 unsafe {
116 // Optimistically assume the `JoinHandle` is being dropped just after creating the
117 // task. This is a common case so if the handle is not used, the overhead of it is only
118 // one compare-exchange operation.
119 if let Err(mut state) = (*header).state.compare_exchange_weak(
120 SCHEDULED | HANDLE | REFERENCE,
121 SCHEDULED | REFERENCE,
122 Ordering::AcqRel,
123 Ordering::Acquire,
124 ) {
125 loop {
126 // If the task has been completed but not yet closed, that means its output
127 // must be dropped.
128 if state & COMPLETED != 0 && state & CLOSED == 0 {
129 // Mark the task as closed in order to grab its output.
130 match (*header).state.compare_exchange_weak(
131 state,
132 state | CLOSED,
133 Ordering::AcqRel,
134 Ordering::Acquire,
135 ) {
136 Ok(_) => {
137 // Read the output.
138 output =
139 Some((((*header).vtable.get_output)(ptr) as *mut R).read());
140
141 // Update the state variable because we're continuing the loop.
142 state |= CLOSED;
143 }
144 Err(s) => state = s,
145 }
146 } else {
147 // If this is the last reference to the task and it's not closed, then
148 // close it and schedule one more time so that its future gets dropped by
149 // the executor.
150 let new = if state & (!(REFERENCE - 1) | CLOSED) == 0 {
151 SCHEDULED | CLOSED | REFERENCE
152 } else {
153 state & !HANDLE
154 };
155
156 // Unset the handle flag.
157 match (*header).state.compare_exchange_weak(
158 state,
159 new,
160 Ordering::AcqRel,
161 Ordering::Acquire,
162 ) {
163 Ok(_) => {
164 // If this is the last reference to the task, we need to either
165 // schedule dropping its future or destroy it.
166 if state & !(REFERENCE - 1) == 0 {
167 if state & CLOSED == 0 {
168 ((*header).vtable.schedule)(ptr);
169 } else {
170 ((*header).vtable.destroy)(ptr);
171 }
172 }
173
174 break;
175 }
176 Err(s) => state = s,
177 }
178 }
179 }
180 }
181 }
182
183 // Drop the output if it was taken out of the task.
184 drop(output);
185 }
186}
187
188impl<R, T> Future for JoinHandle<R, T> {
189 type Output = Option<R>;
190
191 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
192 let ptr = self.raw_task.as_ptr();
193 let header = ptr as *const Header;
194
195 unsafe {
196 let mut state = (*header).state.load(Ordering::Acquire);
197
198 loop {
199 // If the task has been closed, notify the awaiter and return `None`.
200 if state & CLOSED != 0 {
201 // If the task is scheduled or running, we need to wait until its future is
202 // dropped.
203 if state & (SCHEDULED | RUNNING) != 0 {
204 // Replace the waker with one associated with the current task.
205 (*header).register(cx.waker());
206
207 // Reload the state after registering. It is possible changes occurred just
208 // before registration so we need to check for that.
209 state = (*header).state.load(Ordering::Acquire);
210
211 // If the task is still scheduled or running, we need to wait because its
212 // future is not dropped yet.
213 if state & (SCHEDULED | RUNNING) != 0 {
214 return Poll::Pending;
215 }
216 }
217
218 // Even though the awaiter is most likely the current task, it could also be
219 // another task.
220 (*header).notify(Some(cx.waker()));
221 return Poll::Ready(None);
222 }
223
224 // If the task is not completed, register the current task.
225 if state & COMPLETED == 0 {
226 // Replace the waker with one associated with the current task.
227 (*header).register(cx.waker());
228
229 // Reload the state after registering. It is possible that the task became
230 // completed or closed just before registration so we need to check for that.
231 state = (*header).state.load(Ordering::Acquire);
232
233 // If the task has been closed, restart.
234 if state & CLOSED != 0 {
235 continue;
236 }
237
238 // If the task is still not completed, we're blocked on it.
239 if state & COMPLETED == 0 {
240 return Poll::Pending;
241 }
242 }
243
244 // Since the task is now completed, mark it as closed in order to grab its output.
245 match (*header).state.compare_exchange(
246 state,
247 state | CLOSED,
248 Ordering::AcqRel,
249 Ordering::Acquire,
250 ) {
251 Ok(_) => {
252 // Notify the awaiter. Even though the awaiter is most likely the current
253 // task, it could also be another task.
254 if state & AWAITER != 0 {
255 (*header).notify(Some(cx.waker()));
256 }
257
258 // Take the output from the task.
259 let output = ((*header).vtable.get_output)(ptr) as *mut R;
260 return Poll::Ready(Some(output.read()));
261 }
262 Err(s) => state = s,
263 }
264 }
265 }
266 }
267}
268
269impl<R, T> fmt::Debug for JoinHandle<R, T> {
270 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
271 let ptr = self.raw_task.as_ptr();
272 let header = ptr as *const Header;
273
274 f.debug_struct("JoinHandle")
275 .field("header", unsafe { &(*header) })
276 .finish()
277 }
278}