monoio/task/
harness.rs

1use std::{
2    future::Future,
3    panic,
4    ptr::NonNull,
5    task::{Context, Poll, Waker},
6};
7
8use super::utils::UnsafeCellExt;
9use crate::{
10    task::{
11        core::{Cell, Core, CoreStage, Header, Trailer},
12        state::Snapshot,
13        waker::waker_ref,
14        Schedule, Task,
15    },
16    utils::thread_id::{try_get_current_thread_id, DEFAULT_THREAD_ID},
17};
18
/// Typed view over a raw task allocation.
///
/// `T` is the future stored in the task and `S` is the scheduler type. The
/// harness is a thin wrapper around a pointer to the task's `Cell<T, S>`; the
/// methods implemented on it read the cell's `header`, `trailer` and `core`
/// fields through that pointer.
pub(crate) struct Harness<T: Future, S: 'static> {
    // Non-null pointer to the task cell. It is dereferenced by the accessor
    // methods and turned back into a `Box` when the task is deallocated.
    cell: NonNull<Cell<T, S>>,
}
22
23impl<T, S> Harness<T, S>
24where
25    T: Future,
26    S: 'static,
27{
28    pub(crate) unsafe fn from_raw(ptr: NonNull<Header>) -> Harness<T, S> {
29        Harness {
30            cell: ptr.cast::<Cell<T, S>>(),
31        }
32    }
33
34    fn header(&self) -> &Header {
35        unsafe { &self.cell.as_ref().header }
36    }
37
38    fn trailer(&self) -> &Trailer {
39        unsafe { &self.cell.as_ref().trailer }
40    }
41
42    fn core(&self) -> &Core<T, S> {
43        unsafe { &self.cell.as_ref().core }
44    }
45}
46
47impl<T, S> Harness<T, S>
48where
49    T: Future,
50    S: Schedule,
51{
52    /// Polls the inner future.
53    pub(super) fn poll(self) {
54        trace!("MONOIO DEBUG[Harness]:: poll");
55        match self.poll_inner() {
56            PollFuture::Notified => {
57                // We should re-schedule the task.
58                self.header().state.ref_inc();
59                self.core().scheduler.yield_now(self.get_new_task());
60            }
61            PollFuture::Complete => {
62                self.complete();
63            }
64            PollFuture::Done => (),
65        }
66    }
67
68    /// Do polland return the status.
69    ///
70    /// poll_inner does not take a ref-count. We must make sure the task is
71    /// alive when call this method
72    fn poll_inner(&self) -> PollFuture {
73        // notified -> running
74        self.header().state.transition_to_running();
75
76        // poll the future
77        let waker_ref = waker_ref::<T, S>(self.header());
78        let cx = Context::from_waker(&waker_ref);
79        let res = poll_future(&self.core().stage, cx);
80
81        if res == Poll::Ready(()) {
82            return PollFuture::Complete;
83        }
84
85        use super::state::TransitionToIdle;
86        match self.header().state.transition_to_idle() {
87            TransitionToIdle::Ok => PollFuture::Done,
88            TransitionToIdle::OkNotified => PollFuture::Notified,
89        }
90    }
91
92    pub(super) fn dealloc(self) {
93        trace!("MONOIO DEBUG[Harness]:: dealloc");
94
95        // Release the join waker, if there is one.
96        self.trailer().waker.with_mut(drop);
97
98        // Check causality
99        self.core().stage.with_mut(drop);
100
101        unsafe {
102            drop(Box::from_raw(self.cell.as_ptr()));
103        }
104    }
105
106    #[cfg(feature = "sync")]
107    pub(super) fn finish(self, val: <T as Future>::Output) {
108        trace!("MONOIO DEBUG[Harness]:: finish");
109        self.header().state.transition_to_running();
110        self.core().stage.store_output(val);
111        self.complete();
112    }
113
114    // ===== join handle =====
115
116    /// Read the task output into `dst`.
117    pub(super) fn try_read_output(self, dst: &mut Poll<T::Output>, waker: &Waker) {
118        trace!("MONOIO DEBUG[Harness]:: try_read_output");
119        if can_read_output(self.header(), self.trailer(), waker) {
120            *dst = Poll::Ready(self.core().stage.take_output());
121        }
122    }
123
124    pub(super) fn drop_join_handle_slow(self) {
125        trace!("MONOIO DEBUG[Harness]:: drop_join_handle_slow");
126
127        let mut maybe_panic = None;
128
129        // Try to unset `JOIN_INTEREST`. This must be done as a first step in
130        // case the task concurrently completed.
131        if self.header().state.unset_join_interested().is_err() {
132            // It is our responsibility to drop the output. This is critical as
133            // the task output may not be `Send` and as such must remain with
134            // the scheduler or `JoinHandle`. i.e. if the output remains in the
135            // task structure until the task is deallocated, it may be dropped
136            // by a Waker on any arbitrary thread.
137            let panic = panic::catch_unwind(panic::AssertUnwindSafe(|| {
138                self.core().stage.drop_future_or_output();
139            }));
140
141            if let Err(panic) = panic {
142                maybe_panic = Some(panic);
143            }
144        }
145
146        // Drop the `JoinHandle` reference, possibly deallocating the task
147        self.drop_reference();
148
149        if let Some(panic) = maybe_panic {
150            panic::resume_unwind(panic);
151        }
152    }
153
154    // ===== waker behavior =====
155
    /// This call consumes a ref-count and notifies the task. This will create a
    /// new Notified and submit it if necessary.
    ///
    /// The caller does not need to hold a ref-count besides the one that was
    /// passed to this call.
    ///
    /// When `is_remote_task` reports that the task belongs to another thread,
    /// the wake is forwarded to that thread as a `Waker` instead of being
    /// scheduled locally.
    ///
    /// # Panics
    ///
    /// Panics if the wake has to be forwarded to another thread while the
    /// `sync` feature is disabled.
    pub(super) fn wake_by_val(self) {
        trace!("MONOIO DEBUG[Harness]:: wake_by_val");
        let owner_id = self.header().owner_id;
        if is_remote_task(owner_id) {
            // When this transition returns `true` no waker is sent: the
            // ref-count handed to this call is simply released.
            if self.header().state.transition_to_notified_without_submit() {
                self.drop_reference();
                return;
            }
            // send to target thread
            trace!("MONOIO DEBUG[Harness]:: wake_by_val with another thread id");
            #[cfg(feature = "sync")]
            {
                use crate::task::waker::raw_waker;
                let waker = raw_waker::<T, S>(self.cell.cast::<Header>().as_ptr());
                // # Ref Count: self -> waker
                // No `ref_inc` here: the ref-count passed in with `self` is the
                // one the new waker carries.
                let waker = unsafe { Waker::from_raw(waker) };
                crate::runtime::CURRENT.try_with(|maybe_ctx| match maybe_ctx {
                    // A context is available: send the waker to the owner
                    // thread and unpark it.
                    Some(ctx) => {
                        ctx.send_waker(owner_id, waker);
                        ctx.unpark_thread(owner_id);
                    }
                    // No context was handed to us (presumably this thread has
                    // no runtime set): run the same send with `DEFAULT_CTX`
                    // installed as `CURRENT`. The result of the outer
                    // `try_with` is deliberately ignored.
                    None => {
                        let _ = crate::runtime::DEFAULT_CTX.try_with(|default_ctx| {
                            crate::runtime::CURRENT.set(default_ctx, || {
                                crate::runtime::CURRENT.with(|ctx| {
                                    ctx.send_waker(owner_id, waker);
                                    ctx.unpark_thread(owner_id);
                                });
                            });
                        });
                    }
                });
                return;
            }
            #[cfg(not(feature = "sync"))]
            {
                panic!("waker can only be sent across threads when `sync` feature enabled");
            }
        }

        // Local task: decide whether it has to be submitted to the scheduler.
        use super::state::TransitionToNotified;
        match self.header().state.transition_to_notified() {
            TransitionToNotified::Submit => {
                // # Ref Count: self -> task
                self.core().scheduler.schedule(self.get_new_task());
            }
            TransitionToNotified::DoNothing => {
                // # Ref Count: self -> -1
                self.drop_reference();
            }
        }
    }
213
    /// This call notifies the task. It will not consume any ref-counts, but the
    /// caller should hold a ref-count.  This will create a new Notified and
    /// submit it if necessary.
    ///
    /// When `is_remote_task` reports that the task belongs to another thread,
    /// the wake is forwarded to that thread as a newly created `Waker`.
    ///
    /// # Panics
    ///
    /// Panics if the wake has to be forwarded to another thread while the
    /// `sync` feature is disabled.
    pub(super) fn wake_by_ref(&self) {
        trace!("MONOIO DEBUG[Harness]:: wake_by_ref");
        let owner_id = self.header().owner_id;
        if is_remote_task(owner_id) {
            // When this transition returns `true` no waker is sent. Nothing is
            // released either, because this call does not own a ref-count.
            if self.header().state.transition_to_notified_without_submit() {
                return;
            }

            // send to target thread
            trace!("MONOIO DEBUG[Harness]:: wake_by_ref with another thread id");
            #[cfg(feature = "sync")]
            {
                use crate::task::waker::raw_waker;
                let waker = raw_waker::<T, S>(self.cell.cast::<Header>().as_ptr());
                // We create a new waker so we need to inc ref count.
                let waker = unsafe { Waker::from_raw(waker) };
                self.header().state.ref_inc();
                crate::runtime::CURRENT.try_with(|maybe_ctx| match maybe_ctx {
                    // A context is available: send the waker to the owner
                    // thread and unpark it.
                    Some(ctx) => {
                        ctx.send_waker(owner_id, waker);
                        ctx.unpark_thread(owner_id);
                    }
                    // No context was handed to us (presumably this thread has
                    // no runtime set): run the same send with `DEFAULT_CTX`
                    // installed as `CURRENT`. The result of the outer
                    // `try_with` is deliberately ignored.
                    None => {
                        let _ = crate::runtime::DEFAULT_CTX.try_with(|default_ctx| {
                            crate::runtime::CURRENT.set(default_ctx, || {
                                crate::runtime::CURRENT.with(|ctx| {
                                    ctx.send_waker(owner_id, waker);
                                    ctx.unpark_thread(owner_id);
                                });
                            });
                        });
                    }
                });
                return;
            }
            #[cfg(not(feature = "sync"))]
            {
                panic!("waker can only be sent across threads when `sync` feature enabled");
            }
        }

        // Local task: decide whether it has to be submitted to the scheduler.
        use super::state::TransitionToNotified;
        match self.header().state.transition_to_notified() {
            TransitionToNotified::Submit => {
                // # Ref Count: +1 -> task
                self.header().state.ref_inc();
                self.core().scheduler.schedule(self.get_new_task());
            }
            TransitionToNotified::DoNothing => (),
        }
    }
268
269    pub(super) fn drop_reference(self) {
270        trace!("MONOIO DEBUG[Harness]:: drop_reference");
271        if self.header().state.ref_dec() {
272            self.dealloc();
273        }
274    }
275
276    // ====== internal ======
277
278    /// Complete the task. This method assumes that the state is RUNNING.
279    fn complete(self) {
280        // The future has completed and its output has been written to the task
281        // stage. We transition from running to complete.
282
283        let snapshot = self.header().state.transition_to_complete();
284
285        // We catch panics here in case dropping the future or waking the
286        // JoinHandle panics.
287        let _ = panic::catch_unwind(panic::AssertUnwindSafe(|| {
288            if !snapshot.is_join_interested() {
289                // The `JoinHandle` is not interested in the output of
290                // this task. It is our responsibility to drop the
291                // output.
292                self.core().stage.drop_future_or_output();
293            } else if snapshot.has_join_waker() {
294                // Notify the join handle. The previous transition obtains the
295                // lock on the waker cell.
296                self.trailer().wake_join();
297            }
298        }));
299    }
300
301    /// Create a new task that holds its own ref-count.
302    ///
303    /// # Safety
304    ///
305    /// Any use of `self` after this call must ensure that a ref-count to the
306    /// task holds the task alive until after the use of `self`. Passing the
307    /// returned Task to any method on `self` is unsound if dropping the Task
308    /// could drop `self` before the call on `self` returned.
309    fn get_new_task(&self) -> Task<S> {
310        // safety: The header is at the beginning of the cell, so this cast is
311        // safe.
312        unsafe { Task::from_raw(self.cell.cast()) }
313    }
314}
315
316fn is_remote_task(owner_id: usize) -> bool {
317    if owner_id == DEFAULT_THREAD_ID {
318        return true;
319    }
320    match try_get_current_thread_id() {
321        Some(tid) => owner_id != tid,
322        None => true,
323    }
324}
325
326fn can_read_output(header: &Header, trailer: &Trailer, waker: &Waker) -> bool {
327    // Load a snapshot of the current task state
328    let snapshot = header.state.load();
329
330    debug_assert!(snapshot.is_join_interested());
331
332    if !snapshot.is_complete() {
333        // The waker must be stored in the task struct.
334        let res = if snapshot.has_join_waker() {
335            // There already is a waker stored in the struct. If it matches
336            // the provided waker, then there is no further work to do.
337            // Otherwise, the waker must be swapped.
338            let will_wake = unsafe {
339                // Safety: when `JOIN_INTEREST` is set, only `JOIN_HANDLE`
340                // may mutate the `waker` field.
341                trailer.will_wake(waker)
342            };
343
344            if will_wake {
345                // The task is not complete **and** the waker is up to date,
346                // there is nothing further that needs to be done.
347                return false;
348            }
349
350            // Unset the `JOIN_WAKER` to gain mutable access to the `waker`
351            // field then update the field with the new join worker.
352            //
353            // This requires two atomic operations, unsetting the bit and
354            // then resetting it. If the task transitions to complete
355            // concurrently to either one of those operations, then setting
356            // the join waker fails and we proceed to reading the task
357            // output.
358            header
359                .state
360                .unset_waker()
361                .and_then(|snapshot| set_join_waker(header, trailer, waker.clone(), snapshot))
362        } else {
363            set_join_waker(header, trailer, waker.clone(), snapshot)
364        };
365
366        match res {
367            Ok(_) => return false,
368            Err(snapshot) => {
369                assert!(snapshot.is_complete());
370            }
371        }
372    }
373    true
374}
375
376fn set_join_waker(
377    header: &Header,
378    trailer: &Trailer,
379    waker: Waker,
380    snapshot: Snapshot,
381) -> Result<Snapshot, Snapshot> {
382    assert!(snapshot.is_join_interested());
383    assert!(!snapshot.has_join_waker());
384
385    // Safety: Only the `JoinHandle` may set the `waker` field. When
386    // `JOIN_INTEREST` is **not** set, nothing else will touch the field.
387    unsafe {
388        trailer.set_waker(Some(waker));
389    }
390
391    // Update the `JoinWaker` state accordingly
392    let res = header.state.set_join_waker();
393
394    // If the state could not be updated, then clear the join waker
395    if res.is_err() {
396        unsafe {
397            trailer.set_waker(None);
398        }
399    }
400
401    res
402}
403
/// Outcome of a single `Harness::poll_inner` call; `Harness::poll` matches on
/// it to decide what to do next.
enum PollFuture {
    /// The future returned `Poll::Ready`; `poll` goes on to `complete`.
    Complete,
    /// The future is pending and the idle transition reported `OkNotified`;
    /// `poll` hands the task back to the scheduler via `yield_now`.
    Notified,
    /// The future is pending and the idle transition reported `Ok`; `poll`
    /// does nothing further.
    Done,
}
409
410/// Poll the future. If the future completes, the output is written to the
411/// stage field.
412fn poll_future<T: Future>(core: &CoreStage<T>, cx: Context<'_>) -> Poll<()> {
413    // CHIHAI: For efficiency we do not catch.
414
415    // Poll the future.
416    // let output = panic::catch_unwind(panic::AssertUnwindSafe(|| {
417    //     struct Guard<'a, T: Future> {
418    //         core: &'a CoreStage<T>,
419    //     }
420    //     impl<'a, T: Future> Drop for Guard<'a, T> {
421    //         fn drop(&mut self) {
422    //             // If the future panics on poll, we drop it inside the panic
423    //             // guard.
424    //             self.core.drop_future_or_output();
425    //         }
426    //     }
427    //     let guard = Guard { core };
428    //     let res = guard.core.poll(cx);
429    //     mem::forget(guard);
430    //     res
431    // }));
432    let output = core.poll(cx);
433
434    // Prepare output for being placed in the core stage.
435    let output = match output {
436        // Ok(Poll::Pending) => return Poll::Pending,
437        // Ok(Poll::Ready(output)) => Ok(output),
438        // Err(panic) => Err(JoinError::panic(panic)),
439        Poll::Pending => return Poll::Pending,
440        Poll::Ready(output) => output,
441    };
442
443    // Catch and ignore panics if the future panics on drop.
444    // let _ = panic::catch_unwind(panic::AssertUnwindSafe(|| {
445    //     core.store_output(output);
446    // }));
447    core.store_output(output);
448
449    Poll::Ready(())
450}