1use crate::AsyncStdGlobalRuntime;
3use arta::task::{RuntimeJoinHandle, TaskRuntime};
4use futures::{prelude::Future, FutureExt};
5use std::{
6 panic::AssertUnwindSafe,
7 pin::Pin,
8 task::{Context, Poll},
9};
10
/// Join handle for a task spawned through the async-std runtime adapter.
///
/// The wrapped async-std handle yields a [`std::thread::Result<T>`]: the adapter's
/// `spawn` / `spawn_blocking` run the task under `catch_unwind`, so a panic inside the
/// task is delivered as `Err(payload)` rather than as a value of type `T`.
pub struct AsyncStdJoinHandle<T> {
    // Underlying async-std handle; its output is either the task's value or the
    // panic payload captured by `catch_unwind`.
    inner: async_std::task::JoinHandle<std::thread::Result<T>>,
}
15
16impl<T> Future for AsyncStdJoinHandle<T>
17where
18 T: Send,
19{
20 type Output = std::thread::Result<T>;
21
22 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
23 self.inner.poll_unpin(cx)
24 }
25}
26
27impl<T> RuntimeJoinHandle<T> for AsyncStdJoinHandle<T>
28where
29 T: Send + 'static,
30{
31 fn cancel(self) -> impl Future<Output = Option<std::thread::Result<T>>> + Send {
32 self.inner.cancel()
33 }
34}
35
36impl TaskRuntime for AsyncStdGlobalRuntime {
37 type JoinHandle<T> = AsyncStdJoinHandle<T> where T: Send + 'static;
38
39 fn spawn<R>(&self, future: impl Future<Output = R> + Send + 'static) -> Self::JoinHandle<R>
40 where
41 R: Send + 'static,
42 {
43 AsyncStdJoinHandle {
44 inner: async_std::task::spawn(AssertUnwindSafe(future).catch_unwind()),
45 }
46 }
47
48 fn spawn_blocking<R>(&self, task: impl FnOnce() -> R + Send + 'static) -> Self::JoinHandle<R>
49 where
50 R: Send + 'static,
51 {
52 AsyncStdJoinHandle {
53 inner: async_std::task::spawn_blocking(|| {
54 std::panic::catch_unwind(AssertUnwindSafe(task))
55 }),
56 }
57 }
58}