1use std::{fmt, pin::Pin, task::Context, task::Poll, task::ready};
2
3use async_task::Task;
4
/// Owning handle to an `async_task::Task`.
///
/// Awaiting the handle yields `Ok` with the task's output, or `Err(JoinError)` when the
/// handle no longer holds a task. Dropping the handle detaches the task instead of
/// dropping it.
#[derive(Debug)]
pub struct JoinHandle<T> {
    // `Some` while the handle owns the task; set to `None` when `cancel`, `detach`,
    // or `Drop` takes the task out.
    task: Option<Task<T>>,
}
10
11impl<T> JoinHandle<T> {
12 pub(crate) fn new(task: Task<T>) -> Self {
13 JoinHandle { task: Some(task) }
14 }
15
16 pub fn cancel(mut self) {
18 if let Some(t) = self.task.take() {
19 drop(t.cancel());
20 }
21 }
22
23 pub fn detach(mut self) {
25 if let Some(t) = self.task.take() {
26 t.detach();
27 }
28 }
29
30 pub fn is_finished(&self) -> bool {
32 match &self.task {
33 Some(fut) => fut.is_finished(),
34 None => true,
35 }
36 }
37}
38
39impl<T> Drop for JoinHandle<T> {
40 fn drop(&mut self) {
41 if let Some(fut) = self.task.take() {
42 fut.detach();
43 }
44 }
45}
46
47impl<T> Future for JoinHandle<T> {
48 type Output = Result<T, JoinError>;
49
50 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
51 Poll::Ready(match self.task.as_mut() {
52 Some(fut) => Ok(ready!(Pin::new(fut).poll(cx))),
53 None => Err(JoinError),
54 })
55 }
56}
57
/// Error produced by awaiting a [`JoinHandle`] that no longer holds a task.
///
/// Derives equality and hashing so callers (and tests) can compare and match on it.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub struct JoinError;

impl fmt::Display for JoinError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("JoinError")
    }
}

impl std::error::Error for JoinError {}