gstthreadshare/runtime/executor/
join.rs1use futures::prelude::*;
11
12use std::fmt;
13use std::future::Future;
14use std::pin::Pin;
15use std::task::Poll;
16
17use super::TaskId;
18use super::{context::Context, scheduler};
19
20#[derive(Debug)]
21pub struct JoinError(TaskId);
22
23impl fmt::Display for JoinError {
24 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
25 write!(fmt, "{:?} was cancelled", self.0)
26 }
27}
28
29impl std::error::Error for JoinError {}
30
31pub struct JoinHandle<T> {
32 task: Option<async_task::Task<T>>,
33 task_id: TaskId,
34 scheduler: scheduler::ThrottlingHandle,
35}
36
37unsafe impl<T: Send> Send for JoinHandle<T> {}
38unsafe impl<T: Send> Sync for JoinHandle<T> {}
39
40impl<T> JoinHandle<T> {
41 pub(super) fn new(
42 task_id: TaskId,
43 task: async_task::Task<T>,
44 scheduler: &scheduler::ThrottlingHandle,
45 ) -> Self {
46 JoinHandle {
47 task: Some(task),
48 task_id,
49 scheduler: scheduler.clone(),
50 }
51 }
52
53 pub fn is_current(&self) -> bool {
54 if let Some((cur_scheduler, task_id)) =
55 scheduler::Throttling::current().zip(TaskId::current())
56 {
57 cur_scheduler == self.scheduler && task_id == self.task_id
58 } else {
59 false
60 }
61 }
62
63 pub fn context(&self) -> Context {
64 Context::from(self.scheduler.clone())
65 }
66
67 pub fn task_id(&self) -> TaskId {
68 self.task_id
69 }
70
71 pub fn cancel(mut self) {
72 let _ = self.task.take().map(|task| task.cancel());
73 }
74}
75
76impl<T> Future for JoinHandle<T> {
77 type Output = Result<T, JoinError>;
78
79 fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
80 if self.as_ref().is_current() {
81 panic!("Trying to join task {:?} from itself", self.as_ref());
82 }
83
84 if let Some(task) = self.as_mut().task.as_mut() {
85 task.poll_unpin(cx).map(Ok)
90 } else {
91 Poll::Ready(Err(JoinError(self.task_id)))
92 }
93 }
94}
95
96impl<T> Drop for JoinHandle<T> {
97 fn drop(&mut self) {
98 if let Some(task) = self.task.take() {
99 task.detach();
100 }
101 }
102}
103
104impl<T> fmt::Debug for JoinHandle<T> {
105 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
106 fmt.debug_struct("JoinHandle")
107 .field("context", &self.scheduler.context_name())
108 .field("task_id", &self.task_id)
109 .finish()
110 }
111}