async_rt/
lib.rs

1pub mod arc;
2pub mod global;
3pub mod rt;
4pub mod task;
5pub mod tracker;
6
7#[cfg(feature = "either")]
8mod either;
9
10use std::fmt::{Debug, Formatter};
11
12use futures::channel::mpsc::{Receiver, UnboundedReceiver};
13use futures::future::{AbortHandle, Aborted};
14use futures::SinkExt;
15use std::future::Future;
16use std::pin::Pin;
17use std::sync::Arc;
18use std::task::{Context, Poll};
19
20#[cfg(all(
21    not(feature = "threadpool"),
22    not(feature = "tokio"),
23    not(target_arch = "wasm32")
24))]
25compile_error!(
26    "At least one runtime (i.e 'tokio', 'threadpool', 'wasm-bindgen-futures') must be enabled"
27);
28
29/// An owned permission to join on a task (await its termination).
30///
31/// This can be seen as an equivalent to [`std::thread::JoinHandle`] but for [`Future`] tasks rather than a thread.
32/// Note that the task associated with this `JoinHandle` will start running at the time [`Executor::spawn`] is called as
33/// well as according to the implemented runtime (i.e. [`tokio`]), even if `JoinHandle` has not been awaited.
34///
35/// Dropping `JoinHandle` will not abort or cancel the task. In other words, the task will continue to run in the background
36/// and any return value will be lost.
37///
38/// This `struct` is created by the [`Executor::spawn`].
39pub struct JoinHandle<T> {
40    inner: InnerJoinHandle<T>,
41}
42
43impl<T> Debug for JoinHandle<T> {
44    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
45        f.debug_struct("JoinHandle").finish()
46    }
47}
48
49enum InnerJoinHandle<T> {
50    #[cfg(all(feature = "tokio", not(target_arch = "wasm32")))]
51    TokioHandle(::tokio::task::JoinHandle<T>),
52    #[allow(dead_code)]
53    CustomHandle {
54        inner: Option<futures::channel::oneshot::Receiver<Result<T, Aborted>>>,
55        handle: AbortHandle,
56    },
57    Empty,
58}
59
60impl<T> Default for InnerJoinHandle<T> {
61    fn default() -> Self {
62        Self::Empty
63    }
64}
65
66impl<T> JoinHandle<T> {
67    /// Provide an empty [`JoinHandle`] with no associated task.
68    pub fn empty() -> Self {
69        JoinHandle {
70            inner: InnerJoinHandle::Empty,
71        }
72    }
73}
74
75impl<T> JoinHandle<T> {
76    /// Abort the task associated with the handle.
77    pub fn abort(&self) {
78        match self.inner {
79            #[cfg(all(feature = "tokio", not(target_arch = "wasm32")))]
80            InnerJoinHandle::TokioHandle(ref handle) => handle.abort(),
81            InnerJoinHandle::CustomHandle { ref handle, .. } => handle.abort(),
82            InnerJoinHandle::Empty => {}
83        }
84    }
85
86    /// Check if the task associated with this `JoinHandle` has finished.
87    ///
88    /// Note that this method can return false even if [`JoinHandle::abort`] has been called on the
89    /// task due to the time it may take for the task to cancel.
90    pub fn is_finished(&self) -> bool {
91        match self.inner {
92            #[cfg(all(feature = "tokio", not(target_arch = "wasm32")))]
93            InnerJoinHandle::TokioHandle(ref handle) => handle.is_finished(),
94            InnerJoinHandle::CustomHandle {
95                ref handle,
96                ref inner,
97            } => handle.is_aborted() || inner.is_none(),
98            InnerJoinHandle::Empty => true,
99        }
100    }
101
102    /// Replace the current handle with the provided [`JoinHandle`].
103    ///
104    /// # Safety
105    ///
106    /// Note that if this is called with a non-empty handle, the existing task
107    /// will not be terminated when it is replaced.
108    pub unsafe fn replace(&mut self, mut handle: JoinHandle<T>) {
109        self.inner = std::mem::take(&mut handle.inner);
110    }
111
112    /// Replace the current handle with the provided [`JoinHandle`].
113    ///
114    /// # Safety
115    ///
116    /// Note that if this is called with a non-empty handle, the existing task
117    /// will not be terminated when it is replaced.
118    pub unsafe fn replace_in_place(&mut self, handle: &mut JoinHandle<T>) {
119        self.inner = std::mem::take(&mut handle.inner);
120    }
121}
122
123impl<T> Future for JoinHandle<T> {
124    type Output = std::io::Result<T>;
125    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
126        let inner = &mut self.inner;
127        match inner {
128            #[cfg(all(feature = "tokio", not(target_arch = "wasm32")))]
129            InnerJoinHandle::TokioHandle(handle) => {
130                let fut = futures::ready!(Pin::new(handle).poll(cx));
131
132                match fut {
133                    Ok(val) => Poll::Ready(Ok(val)),
134                    Err(e) => {
135                        let e = std::io::Error::other(e);
136                        Poll::Ready(Err(e))
137                    }
138                }
139            }
140            InnerJoinHandle::CustomHandle { inner, .. } => {
141                let Some(this) = inner.as_mut() else {
142                    unreachable!("cannot poll a completed future");
143                };
144
145                let fut = futures::ready!(Pin::new(this).poll(cx));
146                inner.take();
147
148                match fut {
149                    Ok(Ok(val)) => Poll::Ready(Ok(val)),
150                    Ok(Err(e)) => {
151                        let e = std::io::Error::other(e);
152                        Poll::Ready(Err(e))
153                    }
154                    Err(e) => {
155                        let e = std::io::Error::other(e);
156                        Poll::Ready(Err(e))
157                    }
158                }
159            }
160            InnerJoinHandle::Empty => {
161                Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::Other)))
162            }
163        }
164    }
165}
166
167/// The same as [`JoinHandle`] but designed to abort the task when all associated references
168/// to the returned `AbortableJoinHandle` have been dropped.
169#[derive(Clone)]
170pub struct AbortableJoinHandle<T> {
171    handle: Arc<InnerHandle<T>>,
172}
173
174impl<T> Debug for AbortableJoinHandle<T> {
175    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
176        f.debug_struct("AbortableJoinHandle").finish()
177    }
178}
179
180impl<T> From<JoinHandle<T>> for AbortableJoinHandle<T> {
181    fn from(handle: JoinHandle<T>) -> Self {
182        AbortableJoinHandle {
183            handle: Arc::new(InnerHandle {
184                inner: parking_lot::Mutex::new(handle),
185            }),
186        }
187    }
188}
189
190impl<T> AbortableJoinHandle<T> {
191    /// Provide a empty [`AbortableJoinHandle`] with no associated task.
192    pub fn empty() -> Self {
193        Self {
194            handle: Arc::new(InnerHandle {
195                inner: parking_lot::Mutex::new(JoinHandle::empty()),
196            }),
197        }
198    }
199}
200
201impl<T> AbortableJoinHandle<T> {
202    /// See [`JoinHandle::abort`]
203    pub fn abort(&self) {
204        self.handle.inner.lock().abort();
205    }
206
207    /// See [`JoinHandle::is_finished`]
208    pub fn is_finished(&self) -> bool {
209        self.handle.inner.lock().is_finished()
210    }
211
212    /// Replace the current handle with an existing one.
213    ///
214    /// # Safety
215    ///
216    /// Note that if this is called with a non-empty handle, the existing task
217    /// will not be terminated when it is replaced.
218    pub unsafe fn replace(&mut self, inner: AbortableJoinHandle<T>) {
219        let current_handle = &mut *self.handle.inner.lock();
220        let inner_handle = &mut *inner.handle.inner.lock();
221        current_handle.replace_in_place(inner_handle);
222    }
223}
224
225struct InnerHandle<T> {
226    pub inner: parking_lot::Mutex<JoinHandle<T>>,
227}
228
229impl<T> Drop for InnerHandle<T> {
230    fn drop(&mut self) {
231        self.inner.lock().abort();
232    }
233}
234
235impl<T> Future for AbortableJoinHandle<T> {
236    type Output = std::io::Result<T>;
237    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
238        let inner = &mut *self.handle.inner.lock();
239        Pin::new(inner).poll(cx).map_err(std::io::Error::other)
240    }
241}
242
243/// A task that accepts messages
244pub struct CommunicationTask<T> {
245    _task_handle: AbortableJoinHandle<()>,
246    _channel_tx: futures::channel::mpsc::Sender<T>,
247}
248
249unsafe impl<T: Send> Send for CommunicationTask<T> {}
250unsafe impl<T: Send> Sync for CommunicationTask<T> {}
251
252impl<T> Clone for CommunicationTask<T> {
253    fn clone(&self) -> Self {
254        CommunicationTask {
255            _task_handle: self._task_handle.clone(),
256            _channel_tx: self._channel_tx.clone(),
257        }
258    }
259}
260
261impl<T> Debug for CommunicationTask<T> {
262    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
263        f.debug_struct("CommunicationTask").finish()
264    }
265}
266
267impl<T> CommunicationTask<T>
268where
269    T: Send + Sync + 'static,
270{
271    /// Send a message to the task
272    pub async fn send(&mut self, data: T) -> std::io::Result<()> {
273        self._channel_tx
274            .send(data)
275            .await
276            .map_err(std::io::Error::other)
277    }
278
279    /// Attempts to send a message to the task, returning an error if the channel is full or closed due to the task being aborted.
280    pub fn try_send(&self, data: T) -> std::io::Result<()> {
281        self._channel_tx
282            .clone()
283            .try_send(data)
284            .map_err(std::io::Error::other)
285    }
286
287    /// Abort the task
288    pub fn abort(mut self) {
289        self._channel_tx.close_channel();
290        self._task_handle.abort();
291    }
292
293    /// Check to determine if the task is active.
294    pub fn is_active(&self) -> bool {
295        !self._task_handle.is_finished() && !self._channel_tx.is_closed()
296    }
297}
298
299/// A task that accepts messages
300pub struct UnboundedCommunicationTask<T> {
301    _task_handle: AbortableJoinHandle<()>,
302    _channel_tx: futures::channel::mpsc::UnboundedSender<T>,
303}
304
305unsafe impl<T: Send> Send for UnboundedCommunicationTask<T> {}
306unsafe impl<T: Send> Sync for UnboundedCommunicationTask<T> {}
307
308impl<T> Clone for UnboundedCommunicationTask<T> {
309    fn clone(&self) -> Self {
310        UnboundedCommunicationTask {
311            _task_handle: self._task_handle.clone(),
312            _channel_tx: self._channel_tx.clone(),
313        }
314    }
315}
316
317impl<T> Debug for UnboundedCommunicationTask<T> {
318    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
319        f.debug_struct("UnboundedCommunicationTask").finish()
320    }
321}
322
323impl<T> UnboundedCommunicationTask<T>
324where
325    T: Send + Sync + 'static,
326{
327    /// Send a message to task
328    pub fn send(&mut self, data: T) -> std::io::Result<()> {
329        self._channel_tx
330            .unbounded_send(data)
331            .map_err(std::io::Error::other)
332    }
333
334    /// Abort the task
335    pub fn abort(self) {
336        self._channel_tx.close_channel();
337        self._task_handle.abort();
338    }
339
340    /// Check to determine if the task is active.
341    pub fn is_active(&self) -> bool {
342        !self._task_handle.is_finished() && !self._channel_tx.is_closed()
343    }
344}
345
346pub trait Executor {
347    /// Spawns a new asynchronous task in the background, returning a Future [`JoinHandle`] for it.
348    fn spawn<F>(&self, future: F) -> JoinHandle<F::Output>
349    where
350        F: Future + Send + 'static,
351        F::Output: Send + 'static;
352
353    /// Spawns a new asynchronous task in the background, returning an abortable handle that will cancel the task
354    /// once the handle is dropped.
355    ///
356    /// Note: This function is used if the task is expected to run until the handle is dropped. It is recommended to use
357    /// [`Executor::spawn`] or [`Executor::dispatch`] otherwise.
358    fn spawn_abortable<F>(&self, future: F) -> AbortableJoinHandle<F::Output>
359    where
360        F: Future + Send + 'static,
361        F::Output: Send + 'static,
362    {
363        let handle = self.spawn(future);
364        handle.into()
365    }
366
367    /// Spawns a new asynchronous task in the background without an handle.
368    /// Basically the same as [`Executor::spawn`].
369    fn dispatch<F>(&self, future: F)
370    where
371        F: Future + Send + 'static,
372        F::Output: Send + 'static,
373    {
374        self.spawn(future);
375    }
376
377    /// Spawns a new asynchronous task that accepts messages to the task using [`channels`](futures::channel::mpsc).
378    /// This function returns a handle that allows sending a message, or if there is no reference to the handle at all
379    /// (in other words, all handles are dropped), the task would be aborted.
380    fn spawn_coroutine<T, F, Fut>(&self, mut f: F) -> CommunicationTask<T>
381    where
382        F: FnMut(Receiver<T>) -> Fut,
383        Fut: Future<Output = ()> + Send + 'static,
384    {
385        let (tx, rx) = futures::channel::mpsc::channel(1);
386        let fut = f(rx);
387        let _task_handle = self.spawn_abortable(fut);
388        CommunicationTask {
389            _task_handle,
390            _channel_tx: tx,
391        }
392    }
393
394    /// Spawns a new asynchronous task with provided context that accepts messages to the task using [`channels`](futures::channel::mpsc).
395    /// This function returns a handle that allows sending a message, or if there is no reference to the handle at all
396    /// (in other words, all handles are dropped), the task would be aborted.
397    fn spawn_coroutine_with_context<T, F, C, Fut>(
398        &self,
399        context: C,
400        mut f: F,
401    ) -> CommunicationTask<T>
402    where
403        F: FnMut(C, Receiver<T>) -> Fut,
404        Fut: Future<Output = ()> + Send + 'static,
405    {
406        let (tx, rx) = futures::channel::mpsc::channel(1);
407        let fut = f(context, rx);
408        let _task_handle = self.spawn_abortable(fut);
409        CommunicationTask {
410            _task_handle,
411            _channel_tx: tx,
412        }
413    }
414
415    /// Spawns a new asynchronous task that accepts messages to the task using [`channels`](futures::channel::mpsc).
416    /// This function returns a handle that allows sending a message, or if there is no reference to the handle at all
417    /// (in other words, all handles are dropped), the task would be aborted.
418    fn spawn_unbounded_coroutine<T, F, Fut>(&self, mut f: F) -> UnboundedCommunicationTask<T>
419    where
420        F: FnMut(UnboundedReceiver<T>) -> Fut,
421        Fut: Future<Output = ()> + Send + 'static,
422    {
423        let (tx, rx) = futures::channel::mpsc::unbounded();
424        let fut = f(rx);
425        let _task_handle = self.spawn_abortable(fut);
426        UnboundedCommunicationTask {
427            _task_handle,
428            _channel_tx: tx,
429        }
430    }
431
432    /// Spawns a new asynchronous task with provided context that accepts messages to the task using [`channels`](futures::channel::mpsc).
433    /// This function returns a handle that allows sending a message, or if there is no reference to the handle at all
434    /// (in other words, all handles are dropped), the task would be aborted.
435    fn spawn_unbounded_coroutine_with_context<T, F, C, Fut>(
436        &self,
437        context: C,
438        mut f: F,
439    ) -> UnboundedCommunicationTask<T>
440    where
441        F: FnMut(C, UnboundedReceiver<T>) -> Fut,
442        Fut: Future<Output = ()> + Send + 'static,
443    {
444        let (tx, rx) = futures::channel::mpsc::unbounded();
445        let fut = f(context, rx);
446        let _task_handle = self.spawn_abortable(fut);
447        UnboundedCommunicationTask {
448            _task_handle,
449            _channel_tx: tx,
450        }
451    }
452}
453
#[cfg(test)]
mod tests {
    use crate::{Executor, InnerJoinHandle, JoinHandle};
    use futures::future::{AbortHandle, Abortable};
    use std::future::Future;

    /// Waits five seconds and then signals through `tx`. The test expects this task to be
    /// aborted long before that, so reaching the end is a bug.
    async fn task(tx: futures::channel::oneshot::Sender<()>) {
        futures_timer::Delay::new(std::time::Duration::from_secs(5)).await;
        let _ = tx.send(());
        unreachable!();
    }

    /// Minimal `Executor` backed by a `futures` thread pool that hands out `CustomHandle`
    /// join handles.
    struct FuturesExecutor {
        pool: futures::executor::ThreadPool,
    }

    impl Default for FuturesExecutor {
        fn default() -> Self {
            Self {
                pool: futures::executor::ThreadPool::new().unwrap(),
            }
        }
    }

    impl Executor for FuturesExecutor {
        fn spawn<F>(&self, future: F) -> JoinHandle<F::Output>
        where
            F: Future + Send + 'static,
            F::Output: Send + 'static,
        {
            let (abort_handle, registration) = AbortHandle::new_pair();
            let (result_tx, result_rx) = futures::channel::oneshot::channel();
            let abortable = Abortable::new(future, registration);

            self.pool.spawn_ok(async move {
                // The receiver may already be gone if the join handle was dropped.
                let _ = result_tx.send(abortable.await);
            });

            JoinHandle {
                inner: InnerJoinHandle::CustomHandle {
                    inner: Some(result_rx),
                    handle: abort_handle,
                },
            }
        }
    }

    #[test]
    fn custom_abortable_task() {
        futures::executor::block_on(async {
            let executor = FuturesExecutor::default();

            let (tx, rx) = futures::channel::oneshot::channel::<()>();
            let handle = executor.spawn_abortable(task(tx));
            // Dropping the only abortable handle must abort the task, which drops `tx` before
            // it can send, so the receiver observes cancellation.
            drop(handle);
            assert!(rx.await.is_err());
        });
    }
}