gstthreadshare/runtime/executor/
join.rs1use futures::prelude::*;
11
12use std::fmt;
13use std::future::Future;
14use std::pin::Pin;
15use std::task::Poll;
16
17use super::TaskId;
18use super::{context::Context, scheduler};
19
20#[derive(Debug)]
21pub struct JoinError(TaskId);
22
23impl fmt::Display for JoinError {
24 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
25 write!(fmt, "{:?} was cancelled", self.0)
26 }
27}
28
29impl std::error::Error for JoinError {}
30
31pub struct JoinHandle<T> {
32 task: Option<async_task::Task<T>>,
33 task_id: TaskId,
34 scheduler: scheduler::ThrottlingHandle,
35}
36
37unsafe impl<T: Send> Send for JoinHandle<T> {}
38unsafe impl<T: Send> Sync for JoinHandle<T> {}
39
40impl<T> JoinHandle<T> {
41 pub(super) fn new(
42 task_id: TaskId,
43 task: async_task::Task<T>,
44 scheduler: &scheduler::ThrottlingHandle,
45 ) -> Self {
46 JoinHandle {
47 task: Some(task),
48 task_id,
49 scheduler: scheduler.clone(),
50 }
51 }
52
53 pub fn is_current(&self) -> bool {
54 match scheduler::Throttling::current().zip(TaskId::current()) {
55 Some((cur_scheduler, task_id)) => {
56 cur_scheduler == self.scheduler && task_id == self.task_id
57 }
58 _ => false,
59 }
60 }
61
62 pub fn context(&self) -> Context {
63 Context::from(self.scheduler.clone())
64 }
65
66 pub fn task_id(&self) -> TaskId {
67 self.task_id
68 }
69
70 pub fn cancel(mut self) {
71 let _ = self.task.take().map(|task| task.cancel());
72 }
73}
74
75impl<T> Future for JoinHandle<T> {
76 type Output = Result<T, JoinError>;
77
78 fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
79 if self.as_ref().is_current() {
80 panic!("Trying to join task {:?} from itself", self.as_ref());
81 }
82
83 if let Some(task) = self.as_mut().task.as_mut() {
84 task.poll_unpin(cx).map(Ok)
89 } else {
90 Poll::Ready(Err(JoinError(self.task_id)))
91 }
92 }
93}
94
95impl<T> Drop for JoinHandle<T> {
96 fn drop(&mut self) {
97 if let Some(task) = self.task.take() {
98 task.detach();
99 }
100 }
101}
102
103impl<T> fmt::Debug for JoinHandle<T> {
104 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
105 fmt.debug_struct("JoinHandle")
106 .field("context", &self.scheduler.context_name())
107 .field("task_id", &self.task_id)
108 .finish()
109 }
110}