async_utility/task/
mod.rs

1// Copyright (c) 2022-2023 Yuki Kishimoto
2// Distributed under the MIT software license
3
4//! Task
5
6use core::fmt;
7
8use futures_util::stream::{AbortHandle, Abortable};
9use futures_util::Future;
10#[cfg(not(target_arch = "wasm32"))]
11use tokio::task::JoinHandle as TokioJoinHandle;
12
13#[cfg(target_arch = "wasm32")]
14mod wasm;
15
16#[cfg(not(target_arch = "wasm32"))]
17use crate::runtime;
18
/// Error returned by task helpers in this module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Error {
    /// The spawned task could not be joined: the underlying join handle reported
    /// an error. The original error value is not preserved.
    JoinError,
}

impl std::error::Error for Error {}

impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            // Constant message: no formatting machinery needed.
            Self::JoinError => f.write_str("impossible to join thread"),
        }
    }
}
35
36/// Join Handle
37pub enum JoinHandle<T> {
38    /// Tokio
39    #[cfg(not(target_arch = "wasm32"))]
40    Tokio(TokioJoinHandle<T>),
41    /// Wasm
42    #[cfg(target_arch = "wasm32")]
43    Wasm(self::wasm::JoinHandle<T>),
44}
45
46impl<T> JoinHandle<T> {
47    /// Join
48    pub async fn join(self) -> Result<T, Error> {
49        match self {
50            #[cfg(not(target_arch = "wasm32"))]
51            Self::Tokio(handle) => handle.await.map_err(|_| Error::JoinError),
52            #[cfg(target_arch = "wasm32")]
53            Self::Wasm(handle) => handle.join().await.map_err(|_| Error::JoinError),
54        }
55    }
56}
57
58/// Spawn new thread
59#[inline]
60#[cfg(not(target_arch = "wasm32"))]
61pub fn spawn<T>(future: T) -> JoinHandle<T::Output>
62where
63    T: Future + Send + 'static,
64    T::Output: Send + 'static,
65{
66    JoinHandle::Tokio(runtime::handle().spawn(future))
67}
68
/// Spawn `future` as a new asynchronous task and return a handle to it.
///
/// On `wasm32` the work is delegated to this module's `wasm` submodule. Unlike
/// the native variant, `Send` is not required.
#[cfg(target_arch = "wasm32")]
pub fn spawn<T>(future: T) -> JoinHandle<T::Output>
where
    T: Future + 'static,
{
    JoinHandle::Wasm(self::wasm::spawn(future))
}
78
79/// Spawn abortable thread
80#[cfg(not(target_arch = "wasm32"))]
81pub fn abortable<T>(future: T) -> AbortHandle
82where
83    T: Future + Send + 'static,
84    T::Output: Send + 'static,
85{
86    let (abort_handle, abort_registration) = AbortHandle::new_pair();
87    let _ = spawn(Abortable::new(future, abort_registration));
88    abort_handle
89}
90
/// Spawn `future` as an abortable task and return the handle that cancels it.
///
/// The future is wrapped in [`Abortable`] and handed to [`spawn`]. Calling
/// `abort()` on the returned [`AbortHandle`] cancels the wrapped future. The
/// task's output is not retrievable: its join handle is discarded.
#[cfg(target_arch = "wasm32")]
pub fn abortable<T>(future: T) -> AbortHandle
where
    T: Future + 'static,
{
    let (handle, registration) = AbortHandle::new_pair();
    let task = Abortable::new(future, registration);
    // NOTE(review): assumes dropping a wasm join handle does not cancel the
    // task - confirm in the `wasm` submodule.
    drop(spawn(task));
    handle
}
101
102#[inline]
103#[cfg(not(target_arch = "wasm32"))]
104pub fn spawn_blocking<F, R>(f: F) -> TokioJoinHandle<R>
105where
106    F: FnOnce() -> R + Send + 'static,
107    R: Send + 'static,
108{
109    runtime::handle().spawn_blocking(f)
110}
111
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;
    use std::time::Duration;

    use super::*;
    use crate::time;

    /// Delay used by the test futures. It is short enough to keep the suite
    /// fast, yet long enough that each task must await a real timer before
    /// producing its value.
    const TASK_DELAY: Duration = Duration::from_millis(100);

    // TODO: test also wasm

    #[tokio::test]
    #[cfg(not(target_arch = "wasm32"))]
    async fn test_is_tokio_context_macros() {
        assert!(runtime::is_tokio_context());
    }

    #[async_std::test]
    #[cfg(not(target_arch = "wasm32"))]
    async fn test_is_tokio_context_in_async_std() {
        let handle = runtime::handle();
        let _guard = handle.enter();
        assert!(runtime::is_tokio_context());
    }

    #[test]
    #[cfg(not(target_arch = "wasm32"))]
    fn test_is_tokio_context_once_lock() {
        let handle = runtime::handle();
        let _guard = handle.enter();
        assert!(runtime::is_tokio_context());
    }

    #[tokio::test]
    #[cfg(not(target_arch = "wasm32"))]
    async fn test_spawn() {
        let future = async {
            time::sleep(TASK_DELAY).await;
            42
        };
        let handle = spawn(future);
        let result = handle.join().await.unwrap();
        assert_eq!(result, 42);
    }

    #[async_std::test]
    #[cfg(not(target_arch = "wasm32"))]
    async fn test_spawn_in_async_std() {
        let future = async {
            time::sleep(TASK_DELAY).await;
            42
        };
        let handle = spawn(future);
        let result = handle.join().await.unwrap();
        assert_eq!(result, 42);
    }

    #[test]
    #[cfg(not(target_arch = "wasm32"))]
    fn test_spawn_in_smol() {
        smol::block_on(async {
            let future = async {
                time::sleep(TASK_DELAY).await;
                42
            };
            let handle = spawn(future);
            let result = handle.join().await.unwrap();
            assert_eq!(result, 42);
        });
    }

    #[test]
    #[cfg(not(target_arch = "wasm32"))]
    fn test_spawn_outside_tokio_ctx() {
        let future = async {
            time::sleep(TASK_DELAY).await;
            42
        };
        let _handle = spawn(future);
    }

    #[tokio::test]
    #[cfg(not(target_arch = "wasm32"))]
    async fn test_spawn_blocking() {
        let handle = spawn_blocking(|| 42);
        let result = handle.await.unwrap();
        assert_eq!(result, 42);
    }

    #[test]
    #[cfg(not(target_arch = "wasm32"))]
    fn test_spawn_blocking_outside_tokio_ctx() {
        let _handle = spawn_blocking(|| 42);
    }

    #[tokio::test]
    #[cfg(not(target_arch = "wasm32"))]
    async fn test_abortable() {
        // Set by the future only if it runs to completion.
        let completed = Arc::new(AtomicBool::new(false));
        let completed_in_task = Arc::clone(&completed);
        let future = async move {
            time::sleep(TASK_DELAY).await;
            completed_in_task.store(true, Ordering::SeqCst);
            42
        };

        let abort_handle = abortable(future);
        abort_handle.abort();
        assert!(abort_handle.is_aborted());

        // Wait well past the point where the future would have finished if it
        // had not been aborted. Then check that it never ran to completion.
        time::sleep(TASK_DELAY * 3).await;
        assert!(!completed.load(Ordering::SeqCst));
    }
}