Skip to main content

compio_executor/
join_handle.rs

1use std::{
2    error::Error,
3    fmt::Display,
4    marker::PhantomData,
5    mem::ManuallyDrop,
6    panic::resume_unwind,
7    pin::Pin,
8    ptr,
9    task::{Context, Poll},
10};
11
12use compio_log::{instrument, trace};
13
14use crate::{Panic, task::Task};
15
16/// A handle that awaits the result of a task.
17///
18/// Dropping a [`JoinHandle`] will cancel the task. To run the task in the
19/// background, use [`JoinHandle::detach`].
20#[must_use = "Drop `JoinHandle` will cancel the task. Use `detach` to run it in background."]
21#[derive(Debug)]
22#[repr(transparent)]
23pub struct JoinHandle<T> {
24    task: Option<Task>,
25    _marker: PhantomData<T>,
26}
27
28/// If T is send, we can poll result from other thread
29unsafe impl<T: Send> Send for JoinHandle<T> {}
30
31/// JoinHandle does not expose any &self interface, so it's unconditionally
32/// Sync.
33unsafe impl<T> Sync for JoinHandle<T> {}
34
35impl<T> Unpin for JoinHandle<T> {}
36
37impl<T> JoinHandle<T> {
38    pub(crate) fn new(task: Task) -> Self {
39        Self {
40            task: Some(task),
41            _marker: PhantomData,
42        }
43    }
44
45    /// Cancel the task and wait for the result, if any.
46    pub async fn cancel(self) -> Option<T> {
47        self.task.as_ref()?.cancel(false);
48        self.await.ok()
49    }
50
51    /// Detach the task to let it run in the background.
52    pub fn detach(self) {
53        unsafe { ptr::drop_in_place(&raw mut ManuallyDrop::new(self).task) };
54    }
55}
56
57/// Task failed to execute to completion.
58#[derive(Debug)]
59pub enum JoinError {
60    /// The task was cancelled.
61    Cancelled,
62    /// The task panicked.
63    Panicked(Panic),
64}
65
/// Extension trait for values that may carry a [`JoinError`]: converts them
/// while re-raising the panic of a panicked task.
pub trait ResumeUnwind {
    /// What remains of the value once a possible panic has been re-raised.
    type Output;

    /// Consumes `self`, resuming the unwind if the task panicked, and returns
    /// the remaining output otherwise.
    fn resume_unwind(self) -> Self::Output;
}
74
75impl<T> ResumeUnwind for Result<T, JoinError> {
76    type Output = Option<T>;
77
78    fn resume_unwind(self) -> Self::Output {
79        match self {
80            Ok(res) => Some(res),
81            Err(JoinError::Cancelled) => None,
82            Err(JoinError::Panicked(e)) => resume_unwind(e),
83        }
84    }
85}
86
87impl JoinError {
88    /// Resume unwind if the task panicked, otherwise do nothing.
89    pub fn resume_unwind(self) {
90        if let JoinError::Panicked(e) = self {
91            resume_unwind(e)
92        }
93    }
94}
95
96impl Display for JoinError {
97    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
98        match self {
99            JoinError::Cancelled => write!(f, "Task was cancelled"),
100            JoinError::Panicked(_) => write!(f, "Task has panicked"),
101        }
102    }
103}
104
105impl Error for JoinError {}
106
107impl<T> Future for JoinHandle<T> {
108    type Output = Result<T, JoinError>;
109
110    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
111        instrument!(compio_log::Level::TRACE, "JoinHandle::poll");
112
113        let task = self.task.as_ref().expect("Cannot poll after completion");
114
115        unsafe { task.poll(cx) }.map(|res| {
116            trace!("Poll ready");
117
118            self.task = None;
119
120            match res {
121                Some(Ok(res)) => Ok(res),
122                Some(Err(e)) => Err(JoinError::Panicked(e)),
123                None => Err(JoinError::Cancelled),
124            }
125        })
126    }
127}
128
129impl<T> Drop for JoinHandle<T> {
130    fn drop(&mut self) {
131        instrument!(compio_log::Level::TRACE, "JoinHandle::drop");
132
133        if let Some(task) = self.task.as_ref() {
134            task.cancel(true);
135        }
136    }
137}