1use futures::channel::oneshot;
2use futures::channel::oneshot::{Receiver, Sender};
3use std::boxed::Box;
4use std::cell::{Cell, RefCell};
5use std::future::Future;
6use std::ops::{Deref, DerefMut};
7use std::pin::Pin;
8use std::rc::Rc;
9use std::task::Poll;
10
11pub fn create_task<F>(future: F) -> (Task, JoinHandle<F::Output>)
12where
13 F: Future + 'static,
14{
15 let (output_tx, output_rx) = oneshot::channel::<F::Output>();
16 let abort = Rc::new(Cell::new(false));
17
18 (
19 Task::from(GenericTask {
20 future: Box::pin(future),
21 output_tx: Some(output_tx),
22 abort: Rc::clone(&abort),
23 }),
24 JoinHandle(RefCell::new(JoinHandleInner::Pending {
25 output_rx: Box::pin(output_rx),
26 abort,
27 })),
28 )
29}
30
/// A type-erased, heap-pinned future with output `()`, as returned by `create_task`.
///
/// Dereferences to the underlying `Pin<Box<dyn Future<Output = ()>>>`, so it can be
/// driven with `task.as_mut().poll(cx)`.
pub struct Task(Pin<Box<dyn Future<Output = ()>>>);

impl Deref for Task {
    type Target = Pin<Box<dyn Future<Output = ()>>>;

    fn deref(&self) -> &Self::Target {
        let Task(pinned) = self;
        pinned
    }
}

impl DerefMut for Task {
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Task(pinned) = self;
        pinned
    }
}

#[cfg(test)]
impl Future for Task {
    type Output = ();

    /// Forwards to the wrapped future. Only compiled for tests, where a `Task`
    /// is handed directly to `tokio::task::spawn_local`.
    fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
        // `Task` only holds a `Pin<Box<_>>`, which is `Unpin`, so unpinning is allowed.
        let this: &mut Task = Pin::get_mut(self);
        this.0.as_mut().poll(cx)
    }
}
55
56impl<F> From<GenericTask<F>> for Task
57where
58 F: Future + 'static,
59{
60 fn from(generic_task: GenericTask<F>) -> Self {
61 Self(Box::pin(generic_task))
62 }
63}
64
65struct GenericTask<F>
66where
67 F: Future + 'static,
68{
69 future: Pin<Box<F>>,
70 output_tx: Option<Sender<F::Output>>,
72 abort: Rc<Cell<bool>>,
73}
74
75impl<F> Future for GenericTask<F>
76where
77 F: Future + 'static,
78{
79 type Output = ();
80
81 fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
82 if self.abort.get() {
83 Poll::Ready(())
84 } else {
85 match Future::poll(self.future.as_mut(), cx) {
86 Poll::Ready(value) => {
87 let _ = self.output_tx.take().unwrap().send(value);
88 Poll::Ready(())
89 }
90
91 Poll::Pending => Poll::Pending,
92 }
93 }
94 }
95}
96
97pub struct JoinHandle<T>(RefCell<JoinHandleInner<T>>);
116
117enum JoinHandleInner<T> {
118 Pending {
119 output_rx: Pin<Box<Receiver<T>>>,
120 abort: Rc<Cell<bool>>,
121 },
122 Finished(
123 Option<T>,
126 ),
127 Aborted,
128}
129
130impl<T> JoinHandle<T> {
131 fn poll(&self) {
132 let mut inner = self.0.borrow_mut();
133
134 if let JoinHandleInner::Pending {
135 output_rx,
136 abort: _,
137 } = &mut *inner
138 {
139 match output_rx.try_recv() {
140 Ok(Some(value)) => *inner = JoinHandleInner::Finished(Some(value)),
141 Ok(None) => { }
142 Err(_) => *inner = JoinHandleInner::Aborted,
143 }
144 }
145 }
146
147 pub fn abort(&self) {
149 let mut inner = self.0.borrow_mut();
150
151 if let JoinHandleInner::Pending {
152 output_rx: _,
153 abort,
154 } = &*inner
155 {
156 abort.set(true);
157 *inner = JoinHandleInner::Aborted;
158 }
159 }
160
161 pub fn is_finished(&self) -> bool {
163 self.poll();
164 matches!(&*self.0.borrow(), JoinHandleInner::Finished(_))
165 }
166
167 pub fn is_aborted(&self) -> bool {
170 self.poll();
171 matches!(&*self.0.borrow(), JoinHandleInner::Aborted)
172 }
173}
174
175impl<T> Future for JoinHandle<T> {
176 type Output = Option<T>;
177
178 fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
179 let mut inner = self.0.borrow_mut();
180
181 match &mut *inner {
182 JoinHandleInner::Pending {
183 output_rx,
184 abort: _,
185 } => match Future::poll(output_rx.as_mut(), cx) {
186 Poll::Ready(Ok(value)) => {
187 *inner = JoinHandleInner::Finished(None);
188 Poll::Ready(Some(value))
189 }
190
191 Poll::Ready(Err(_)) => {
192 *inner = JoinHandleInner::Aborted;
193 Poll::Ready(None)
194 }
195
196 Poll::Pending => Poll::Pending,
197 },
198
199 JoinHandleInner::Finished(value) => Poll::Ready(value.take()),
200 JoinHandleInner::Aborted => Poll::Ready(None),
201 }
202 }
203}
204
/// End-to-end check of `create_task` / `JoinHandle` on a tokio `LocalSet`.
///
/// Scenarios, in order:
/// 1. awaiting the handle yields the task's output once it completes;
/// 2. `abort()` after completion (already observed via `is_finished`) has no effect;
/// 3. `abort()` before completion marks the handle aborted and it resolves to `None`;
/// 4. cancelling the task through tokio's own handle is observed as aborted;
/// 5. dropping the `JoinHandle` detaches rather than cancels: the task still runs.
///
/// NOTE(review): relies on wall-clock sleeps (50/100/500 ms), so it may be flaky
/// on a heavily loaded machine.
#[cfg(test)]
#[tokio::test]
async fn test() {
    use std::time::Duration;
    use tokio::task::LocalSet;
    use tokio::time;

    // `Task` is `!Send` (it holds `Rc`s), so it must run on a `LocalSet`.
    let local_set = LocalSet::new();

    local_set
        .run_until(async {
            // 1. Normal completion: not finished right after spawning, then yields the output.
            let (task, join_handle) = create_task(async {
                time::sleep(Duration::from_millis(50)).await;
                "test"
            });
            tokio::task::spawn_local(task);
            assert!(!join_handle.is_finished());
            assert!(!join_handle.is_aborted());
            assert_eq!(join_handle.await, Some("test"));

            // 2. Abort after completion: once the task has finished, abort() is a no-op.
            let (task, join_handle) = create_task(async {
                time::sleep(Duration::from_millis(50)).await;
                "test"
            });
            tokio::task::spawn_local(task);
            time::sleep(Duration::from_millis(100)).await;
            assert!(join_handle.is_finished());
            assert!(!join_handle.is_aborted());
            join_handle.abort();
            assert!(join_handle.is_finished());
            assert!(!join_handle.is_aborted());
            assert_eq!(join_handle.await, Some("test"));

            // 3. Abort before completion: the handle becomes aborted and resolves to None.
            let (task, join_handle) = create_task(async {
                time::sleep(Duration::from_millis(50)).await;
                "test"
            });
            tokio::task::spawn_local(task);
            assert!(!join_handle.is_finished());
            assert!(!join_handle.is_aborted());
            join_handle.abort();
            assert!(!join_handle.is_finished());
            assert!(join_handle.is_aborted());
            assert_eq!(join_handle.await, None);

            // 4. Executor-side cancellation: aborting via tokio's handle drops the task, and
            //    the dropped output sender is reported as aborted.
            let (task, join_handle) = create_task(async {
                time::sleep(Duration::from_millis(500)).await;
                "test"
            });
            let tokio_join_handle = tokio::task::spawn_local(task);
            assert!(!join_handle.is_finished());
            assert!(!join_handle.is_aborted());
            tokio_join_handle.abort();
            time::sleep(Duration::from_millis(100)).await;
            assert!(!join_handle.is_finished());
            assert!(join_handle.is_aborted());
            assert_eq!(join_handle.await, None);

            // 5. Detach on drop: dropping the JoinHandle does not cancel the task; its side
            //    effect (setting `value` to 1) still happens.
            let value = Rc::new(Cell::new(0i32));
            let (task, join_handle) = create_task({
                let value = Rc::clone(&value);
                async move {
                    time::sleep(Duration::from_millis(50)).await;
                    value.set(1);
                    "test"
                }
            });
            tokio::task::spawn_local(task);
            assert!(!join_handle.is_finished());
            assert!(!join_handle.is_aborted());
            drop(join_handle);
            assert_eq!(value.get(), 0);
            time::sleep(Duration::from_millis(100)).await;
            assert_eq!(value.get(), 1);
        })
        .await;

    // Await the LocalSet so any tasks still spawned on it are driven to completion.
    local_set.await;
}