1use std::{
2 future::Future,
3 panic,
4 ptr::NonNull,
5 task::{Context, Poll, Waker},
6};
7
8use super::utils::UnsafeCellExt;
9use crate::{
10 task::{
11 core::{Cell, Core, CoreStage, Header, Trailer},
12 state::Snapshot,
13 waker::waker_ref,
14 Schedule, Task,
15 },
16 utils::thread_id::{try_get_current_thread_id, DEFAULT_THREAD_ID},
17};
18
19pub(crate) struct Harness<T: Future, S: 'static> {
20 cell: NonNull<Cell<T, S>>,
21}
22
23impl<T, S> Harness<T, S>
24where
25 T: Future,
26 S: 'static,
27{
28 pub(crate) unsafe fn from_raw(ptr: NonNull<Header>) -> Harness<T, S> {
29 Harness {
30 cell: ptr.cast::<Cell<T, S>>(),
31 }
32 }
33
34 fn header(&self) -> &Header {
35 unsafe { &self.cell.as_ref().header }
36 }
37
38 fn trailer(&self) -> &Trailer {
39 unsafe { &self.cell.as_ref().trailer }
40 }
41
42 fn core(&self) -> &Core<T, S> {
43 unsafe { &self.cell.as_ref().core }
44 }
45}
46
47impl<T, S> Harness<T, S>
48where
49 T: Future,
50 S: Schedule,
51{
52 pub(super) fn poll(self) {
54 trace!("MONOIO DEBUG[Harness]:: poll");
55 match self.poll_inner() {
56 PollFuture::Notified => {
57 self.header().state.ref_inc();
59 self.core().scheduler.yield_now(self.get_new_task());
60 }
61 PollFuture::Complete => {
62 self.complete();
63 }
64 PollFuture::Done => (),
65 }
66 }
67
68 fn poll_inner(&self) -> PollFuture {
73 self.header().state.transition_to_running();
75
76 let waker_ref = waker_ref::<T, S>(self.header());
78 let cx = Context::from_waker(&waker_ref);
79 let res = poll_future(&self.core().stage, cx);
80
81 if res == Poll::Ready(()) {
82 return PollFuture::Complete;
83 }
84
85 use super::state::TransitionToIdle;
86 match self.header().state.transition_to_idle() {
87 TransitionToIdle::Ok => PollFuture::Done,
88 TransitionToIdle::OkNotified => PollFuture::Notified,
89 }
90 }
91
92 pub(super) fn dealloc(self) {
93 trace!("MONOIO DEBUG[Harness]:: dealloc");
94
95 self.trailer().waker.with_mut(drop);
97
98 self.core().stage.with_mut(drop);
100
101 unsafe {
102 drop(Box::from_raw(self.cell.as_ptr()));
103 }
104 }
105
106 #[cfg(feature = "sync")]
107 pub(super) fn finish(self, val: <T as Future>::Output) {
108 trace!("MONOIO DEBUG[Harness]:: finish");
109 self.header().state.transition_to_running();
110 self.core().stage.store_output(val);
111 self.complete();
112 }
113
114 pub(super) fn try_read_output(self, dst: &mut Poll<T::Output>, waker: &Waker) {
118 trace!("MONOIO DEBUG[Harness]:: try_read_output");
119 if can_read_output(self.header(), self.trailer(), waker) {
120 *dst = Poll::Ready(self.core().stage.take_output());
121 }
122 }
123
124 pub(super) fn drop_join_handle_slow(self) {
125 trace!("MONOIO DEBUG[Harness]:: drop_join_handle_slow");
126
127 let mut maybe_panic = None;
128
129 if self.header().state.unset_join_interested().is_err() {
132 let panic = panic::catch_unwind(panic::AssertUnwindSafe(|| {
138 self.core().stage.drop_future_or_output();
139 }));
140
141 if let Err(panic) = panic {
142 maybe_panic = Some(panic);
143 }
144 }
145
146 self.drop_reference();
148
149 if let Some(panic) = maybe_panic {
150 panic::resume_unwind(panic);
151 }
152 }
153
154 pub(super) fn wake_by_val(self) {
162 trace!("MONOIO DEBUG[Harness]:: wake_by_val");
163 let owner_id = self.header().owner_id;
164 if is_remote_task(owner_id) {
165 if self.header().state.transition_to_notified_without_submit() {
166 self.drop_reference();
167 return;
168 }
169 trace!("MONOIO DEBUG[Harness]:: wake_by_val with another thread id");
171 #[cfg(feature = "sync")]
172 {
173 use crate::task::waker::raw_waker;
174 let waker = raw_waker::<T, S>(self.cell.cast::<Header>().as_ptr());
175 let waker = unsafe { Waker::from_raw(waker) };
177 crate::runtime::CURRENT.try_with(|maybe_ctx| match maybe_ctx {
178 Some(ctx) => {
179 ctx.send_waker(owner_id, waker);
180 ctx.unpark_thread(owner_id);
181 }
182 None => {
183 let _ = crate::runtime::DEFAULT_CTX.try_with(|default_ctx| {
184 crate::runtime::CURRENT.set(default_ctx, || {
185 crate::runtime::CURRENT.with(|ctx| {
186 ctx.send_waker(owner_id, waker);
187 ctx.unpark_thread(owner_id);
188 });
189 });
190 });
191 }
192 });
193 return;
194 }
195 #[cfg(not(feature = "sync"))]
196 {
197 panic!("waker can only be sent across threads when `sync` feature enabled");
198 }
199 }
200
201 use super::state::TransitionToNotified;
202 match self.header().state.transition_to_notified() {
203 TransitionToNotified::Submit => {
204 self.core().scheduler.schedule(self.get_new_task());
206 }
207 TransitionToNotified::DoNothing => {
208 self.drop_reference();
210 }
211 }
212 }
213
214 pub(super) fn wake_by_ref(&self) {
218 trace!("MONOIO DEBUG[Harness]:: wake_by_ref");
219 let owner_id = self.header().owner_id;
220 if is_remote_task(owner_id) {
221 if self.header().state.transition_to_notified_without_submit() {
222 return;
223 }
224
225 trace!("MONOIO DEBUG[Harness]:: wake_by_ref with another thread id");
227 #[cfg(feature = "sync")]
228 {
229 use crate::task::waker::raw_waker;
230 let waker = raw_waker::<T, S>(self.cell.cast::<Header>().as_ptr());
231 let waker = unsafe { Waker::from_raw(waker) };
233 self.header().state.ref_inc();
234 crate::runtime::CURRENT.try_with(|maybe_ctx| match maybe_ctx {
235 Some(ctx) => {
236 ctx.send_waker(owner_id, waker);
237 ctx.unpark_thread(owner_id);
238 }
239 None => {
240 let _ = crate::runtime::DEFAULT_CTX.try_with(|default_ctx| {
241 crate::runtime::CURRENT.set(default_ctx, || {
242 crate::runtime::CURRENT.with(|ctx| {
243 ctx.send_waker(owner_id, waker);
244 ctx.unpark_thread(owner_id);
245 });
246 });
247 });
248 }
249 });
250 return;
251 }
252 #[cfg(not(feature = "sync"))]
253 {
254 panic!("waker can only be sent across threads when `sync` feature enabled");
255 }
256 }
257
258 use super::state::TransitionToNotified;
259 match self.header().state.transition_to_notified() {
260 TransitionToNotified::Submit => {
261 self.header().state.ref_inc();
263 self.core().scheduler.schedule(self.get_new_task());
264 }
265 TransitionToNotified::DoNothing => (),
266 }
267 }
268
269 pub(super) fn drop_reference(self) {
270 trace!("MONOIO DEBUG[Harness]:: drop_reference");
271 if self.header().state.ref_dec() {
272 self.dealloc();
273 }
274 }
275
276 fn complete(self) {
280 let snapshot = self.header().state.transition_to_complete();
284
285 let _ = panic::catch_unwind(panic::AssertUnwindSafe(|| {
288 if !snapshot.is_join_interested() {
289 self.core().stage.drop_future_or_output();
293 } else if snapshot.has_join_waker() {
294 self.trailer().wake_join();
297 }
298 }));
299 }
300
301 fn get_new_task(&self) -> Task<S> {
310 unsafe { Task::from_raw(self.cell.cast()) }
313 }
314}
315
316fn is_remote_task(owner_id: usize) -> bool {
317 if owner_id == DEFAULT_THREAD_ID {
318 return true;
319 }
320 match try_get_current_thread_id() {
321 Some(tid) => owner_id != tid,
322 None => true,
323 }
324}
325
326fn can_read_output(header: &Header, trailer: &Trailer, waker: &Waker) -> bool {
327 let snapshot = header.state.load();
329
330 debug_assert!(snapshot.is_join_interested());
331
332 if !snapshot.is_complete() {
333 let res = if snapshot.has_join_waker() {
335 let will_wake = unsafe {
339 trailer.will_wake(waker)
342 };
343
344 if will_wake {
345 return false;
348 }
349
350 header
359 .state
360 .unset_waker()
361 .and_then(|snapshot| set_join_waker(header, trailer, waker.clone(), snapshot))
362 } else {
363 set_join_waker(header, trailer, waker.clone(), snapshot)
364 };
365
366 match res {
367 Ok(_) => return false,
368 Err(snapshot) => {
369 assert!(snapshot.is_complete());
370 }
371 }
372 }
373 true
374}
375
376fn set_join_waker(
377 header: &Header,
378 trailer: &Trailer,
379 waker: Waker,
380 snapshot: Snapshot,
381) -> Result<Snapshot, Snapshot> {
382 assert!(snapshot.is_join_interested());
383 assert!(!snapshot.has_join_waker());
384
385 unsafe {
388 trailer.set_waker(Some(waker));
389 }
390
391 let res = header.state.set_join_waker();
393
394 if res.is_err() {
396 unsafe {
397 trailer.set_waker(None);
398 }
399 }
400
401 res
402}
403
/// Outcome of one `Harness::poll_inner` call.
enum PollFuture {
    /// The future returned `Ready` and its output has been stored.
    Complete,
    /// The future is pending, but the idle transition reported a notification,
    /// so the task must be re-scheduled.
    Notified,
    /// The future is pending and the task went idle.
    Done,
}
409
410fn poll_future<T: Future>(core: &CoreStage<T>, cx: Context<'_>) -> Poll<()> {
413 let output = core.poll(cx);
433
434 let output = match output {
436 Poll::Pending => return Poll::Pending,
440 Poll::Ready(output) => output,
441 };
442
443 core.store_output(output);
448
449 Poll::Ready(())
450}