Skip to main content

nio_task/
join.rs

1use crate::{AbortHandle, raw::RawTask};
2
3use super::{COMPLETE, error::JoinError, id::TaskId};
4use std::{
5    fmt,
6    future::Future,
7    marker::PhantomData,
8    pin::Pin,
9    task::{Context, Poll},
10};
11
12pub struct JoinHandle<T> {
13    raw: RawTask,
14    _p: PhantomData<T>,
15}
16unsafe impl<T: Send> Send for JoinHandle<T> {}
17unsafe impl<T: Send> Sync for JoinHandle<T> {}
18
19impl<T> std::panic::UnwindSafe for JoinHandle<T> {}
20impl<T> std::panic::RefUnwindSafe for JoinHandle<T> {}
21
22impl<T> JoinHandle<T> {
23    pub(super) fn new(raw: RawTask) -> JoinHandle<T> {
24        JoinHandle {
25            raw,
26            _p: PhantomData,
27        }
28    }
29
30    #[inline]
31    pub fn abort_handle(&self) -> AbortHandle {
32        AbortHandle {
33            raw: self.raw.clone(),
34        }
35    }
36
37    #[inline]
38    pub fn abort(&self) {
39        self.raw.abort_task();
40    }
41
42    #[inline]
43    pub fn is_finished(&self) -> bool {
44        self.raw.header().state.load().has(COMPLETE)
45    }
46
47    #[inline]
48    pub fn id(&self) -> TaskId {
49        TaskId::new(&self.raw)
50    }
51}
52
53impl<T> Future for JoinHandle<T> {
54    type Output = Result<T, JoinError>;
55
56    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
57        let mut ret: Poll<Result<T, JoinError>> = Poll::Pending;
58        unsafe {
59            self.raw
60                .read_output(&mut ret as *mut _ as *mut (), cx.waker());
61        }
62        ret
63    }
64}
65
66impl<T> Drop for JoinHandle<T> {
67    fn drop(&mut self) {
68        let header = self.raw.header();
69
70        let is_task_complete = header.state.unset_waker_and_interested();
71        if is_task_complete {
72            // If the task is complete then waker is droped by the executor.
73            // We just only need to drop the output.
74            unsafe { self.raw.drop_output_from_join_handler() };
75        } else {
76            unsafe { *header.join_waker.get() = None };
77        }
78    }
79}
80
81impl<T> fmt::Debug for JoinHandle<T> {
82    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
83        fmt.debug_struct("JoinHandle")
84            .field("id", &self.id())
85            .finish()
86    }
87}