Skip to main content

async_rt/
lib.rs

1pub mod arc;
2pub mod global;
3pub mod rt;
4pub mod task;
5pub mod tracker;
6
7#[cfg(feature = "either")]
8pub mod either;
9pub mod rc;
10
11use std::fmt::{Debug, Formatter};
12
13use futures::channel::mpsc::{Receiver, UnboundedReceiver};
14use futures::future::{AbortHandle, Aborted};
15use futures::SinkExt;
16use std::future::Future;
17use std::pin::Pin;
18use std::sync::Arc;
19use std::task::{Context, Poll};
20
21#[cfg(all(
22    not(feature = "threadpool"),
23    not(feature = "tokio"),
24    not(target_arch = "wasm32")
25))]
26compile_error!(
27    "At least one runtime (i.e 'tokio', 'threadpool', 'wasm-bindgen-futures') must be enabled"
28);
29
30/// An owned permission to join on a task (await its termination).
31///
32/// This can be seen as an equivalent to [`std::thread::JoinHandle`] but for [`Future`] tasks rather than a thread.
33/// Note that the task associated with this `JoinHandle` will start running at the time [`Executor::spawn`] is called as
34/// well as according to the implemented runtime (i.e. [`tokio`]), even if `JoinHandle` has not been awaited.
35///
36/// Dropping `JoinHandle` will not abort or cancel the task. In other words, the task will continue to run in the background
37/// and any return value will be lost.
38///
39/// This `struct` is created by the [`Executor::spawn`].
40pub struct JoinHandle<T> {
41    inner: InnerJoinHandle<T>,
42}
43
44impl<T> Debug for JoinHandle<T> {
45    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
46        f.debug_struct("JoinHandle").finish()
47    }
48}
49
50enum InnerJoinHandle<T> {
51    #[cfg(all(feature = "tokio", not(target_arch = "wasm32")))]
52    TokioHandle(::tokio::task::JoinHandle<T>),
53    #[allow(dead_code)]
54    CustomHandle {
55        inner: Option<futures::channel::oneshot::Receiver<Result<T, Aborted>>>,
56        handle: AbortHandle,
57    },
58    Empty,
59}
60
61impl<T> Default for InnerJoinHandle<T> {
62    fn default() -> Self {
63        Self::Empty
64    }
65}
66
67impl<T> JoinHandle<T> {
68    /// Provide an empty [`JoinHandle`] with no associated task.
69    pub fn empty() -> Self {
70        JoinHandle {
71            inner: InnerJoinHandle::Empty,
72        }
73    }
74}
75
76impl<T> JoinHandle<T> {
77    /// Abort the task associated with the handle.
78    pub fn abort(&self) {
79        match self.inner {
80            #[cfg(all(feature = "tokio", not(target_arch = "wasm32")))]
81            InnerJoinHandle::TokioHandle(ref handle) => handle.abort(),
82            InnerJoinHandle::CustomHandle { ref handle, .. } => handle.abort(),
83            InnerJoinHandle::Empty => {}
84        }
85    }
86
87    /// Check if the task associated with this `JoinHandle` has finished.
88    ///
89    /// Note that this method can return false even if [`JoinHandle::abort`] has been called on the
90    /// task due to the time it may take for the task to cancel.
91    pub fn is_finished(&self) -> bool {
92        match self.inner {
93            #[cfg(all(feature = "tokio", not(target_arch = "wasm32")))]
94            InnerJoinHandle::TokioHandle(ref handle) => handle.is_finished(),
95            InnerJoinHandle::CustomHandle {
96                ref handle,
97                ref inner,
98            } => handle.is_aborted() || inner.is_none(),
99            InnerJoinHandle::Empty => true,
100        }
101    }
102
103    /// Replace the current handle with the provided [`JoinHandle`].
104    ///
105    /// # Safety
106    ///
107    /// Note that if this is called with a non-empty handle, the existing task
108    /// will not be terminated when it is replaced.
109    pub unsafe fn replace(&mut self, mut handle: JoinHandle<T>) {
110        self.inner = std::mem::take(&mut handle.inner);
111    }
112
113    /// Replace the current handle with the provided [`JoinHandle`].
114    ///
115    /// # Safety
116    ///
117    /// Note that if this is called with a non-empty handle, the existing task
118    /// will not be terminated when it is replaced.
119    pub unsafe fn replace_in_place(&mut self, handle: &mut JoinHandle<T>) {
120        self.inner = std::mem::take(&mut handle.inner);
121    }
122}
123
124impl<T> Future for JoinHandle<T> {
125    type Output = std::io::Result<T>;
126    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
127        let inner = &mut self.inner;
128        match inner {
129            #[cfg(all(feature = "tokio", not(target_arch = "wasm32")))]
130            InnerJoinHandle::TokioHandle(handle) => {
131                let fut = futures::ready!(Pin::new(handle).poll(cx));
132
133                match fut {
134                    Ok(val) => Poll::Ready(Ok(val)),
135                    Err(e) => {
136                        let e = std::io::Error::other(e);
137                        Poll::Ready(Err(e))
138                    }
139                }
140            }
141            InnerJoinHandle::CustomHandle { inner, .. } => {
142                let Some(this) = inner.as_mut() else {
143                    unreachable!("cannot poll a completed future");
144                };
145
146                let fut = futures::ready!(Pin::new(this).poll(cx));
147                inner.take();
148
149                match fut {
150                    Ok(Ok(val)) => Poll::Ready(Ok(val)),
151                    Ok(Err(e)) => {
152                        let e = std::io::Error::other(e);
153                        Poll::Ready(Err(e))
154                    }
155                    Err(e) => {
156                        let e = std::io::Error::other(e);
157                        Poll::Ready(Err(e))
158                    }
159                }
160            }
161            InnerJoinHandle::Empty => {
162                Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::Other)))
163            }
164        }
165    }
166}
167
168/// The same as [`JoinHandle`] but designed to abort the task when all associated references
169/// to the returned `AbortableJoinHandle` have been dropped.
170#[derive(Clone)]
171pub struct AbortableJoinHandle<T> {
172    handle: Arc<InnerHandle<T>>,
173}
174
175impl<T> Debug for AbortableJoinHandle<T> {
176    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
177        f.debug_struct("AbortableJoinHandle").finish()
178    }
179}
180
181impl<T> From<JoinHandle<T>> for AbortableJoinHandle<T> {
182    fn from(handle: JoinHandle<T>) -> Self {
183        AbortableJoinHandle {
184            handle: Arc::new(InnerHandle {
185                inner: parking_lot::Mutex::new(handle),
186            }),
187        }
188    }
189}
190
191impl<T> AbortableJoinHandle<T> {
192    /// Provide a empty [`AbortableJoinHandle`] with no associated task.
193    pub fn empty() -> Self {
194        Self {
195            handle: Arc::new(InnerHandle {
196                inner: parking_lot::Mutex::new(JoinHandle::empty()),
197            }),
198        }
199    }
200}
201
202impl<T> AbortableJoinHandle<T> {
203    /// See [`JoinHandle::abort`]
204    pub fn abort(&self) {
205        self.handle.inner.lock().abort();
206    }
207
208    /// See [`JoinHandle::is_finished`]
209    pub fn is_finished(&self) -> bool {
210        self.handle.inner.lock().is_finished()
211    }
212
213    /// Replace the current handle with an existing one.
214    ///
215    /// # Safety
216    ///
217    /// Note that if this is called with a non-empty handle, the existing task
218    /// will not be terminated when it is replaced.
219    pub unsafe fn replace(&mut self, inner: AbortableJoinHandle<T>) {
220        let current_handle = &mut *self.handle.inner.lock();
221        let inner_handle = &mut *inner.handle.inner.lock();
222        current_handle.replace_in_place(inner_handle);
223    }
224}
225
226struct InnerHandle<T> {
227    pub inner: parking_lot::Mutex<JoinHandle<T>>,
228}
229
230impl<T> Drop for InnerHandle<T> {
231    fn drop(&mut self) {
232        self.inner.lock().abort();
233    }
234}
235
236impl<T> Future for AbortableJoinHandle<T> {
237    type Output = std::io::Result<T>;
238    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
239        let inner = &mut *self.handle.inner.lock();
240        Pin::new(inner).poll(cx).map_err(std::io::Error::other)
241    }
242}
243
244/// A task that accepts messages
245pub struct CommunicationTask<T> {
246    _task_handle: AbortableJoinHandle<()>,
247    _channel_tx: futures::channel::mpsc::Sender<T>,
248}
249
250unsafe impl<T: Send> Send for CommunicationTask<T> {}
251unsafe impl<T: Send> Sync for CommunicationTask<T> {}
252
253impl<T> Clone for CommunicationTask<T> {
254    fn clone(&self) -> Self {
255        CommunicationTask {
256            _task_handle: self._task_handle.clone(),
257            _channel_tx: self._channel_tx.clone(),
258        }
259    }
260}
261
262impl<T> Debug for CommunicationTask<T> {
263    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
264        f.debug_struct("CommunicationTask").finish()
265    }
266}
267
268impl<T> CommunicationTask<T>
269where
270    T: 'static,
271{
272    /// Send a message to the task
273    pub async fn send(&mut self, data: T) -> std::io::Result<()> {
274        self._channel_tx
275            .send(data)
276            .await
277            .map_err(std::io::Error::other)
278    }
279
280    /// Attempts to send a message to the task, returning an error if the channel is full or closed due to the task being aborted.
281    pub fn try_send(&self, data: T) -> std::io::Result<()>
282    where
283        T: Send + Sync,
284    {
285        self._channel_tx
286            .clone()
287            .try_send(data)
288            .map_err(std::io::Error::other)
289    }
290
291    /// Abort the task
292    pub fn abort(mut self) {
293        self._channel_tx.close_channel();
294        self._task_handle.abort();
295    }
296
297    /// Check to determine if the task is active.
298    pub fn is_active(&self) -> bool {
299        !self._task_handle.is_finished() && !self._channel_tx.is_closed()
300    }
301}
302
303/// A task that accepts messages
304pub struct UnboundedCommunicationTask<T> {
305    _task_handle: AbortableJoinHandle<()>,
306    _channel_tx: futures::channel::mpsc::UnboundedSender<T>,
307}
308
309unsafe impl<T: Send> Send for UnboundedCommunicationTask<T> {}
310unsafe impl<T: Send> Sync for UnboundedCommunicationTask<T> {}
311
312impl<T> Clone for UnboundedCommunicationTask<T> {
313    fn clone(&self) -> Self {
314        UnboundedCommunicationTask {
315            _task_handle: self._task_handle.clone(),
316            _channel_tx: self._channel_tx.clone(),
317        }
318    }
319}
320
321impl<T> Debug for UnboundedCommunicationTask<T> {
322    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
323        f.debug_struct("UnboundedCommunicationTask").finish()
324    }
325}
326
327impl<T> UnboundedCommunicationTask<T>
328where
329    T: 'static,
330{
331    /// Send a message to task
332    pub fn send(&mut self, data: T) -> std::io::Result<()>
333    where
334        T: Send + Sync,
335    {
336        self._channel_tx
337            .unbounded_send(data)
338            .map_err(std::io::Error::other)
339    }
340
341    /// Abort the task
342    pub fn abort(self) {
343        self._channel_tx.close_channel();
344        self._task_handle.abort();
345    }
346
347    /// Check to determine if the task is active.
348    pub fn is_active(&self) -> bool {
349        !self._task_handle.is_finished() && !self._channel_tx.is_closed()
350    }
351}
352
353pub trait Executor {
354    /// Spawns a new asynchronous task in the background, returning a Future [`JoinHandle`] for it.
355    fn spawn<F>(&self, future: F) -> JoinHandle<F::Output>
356    where
357        F: Future + Send + 'static,
358        F::Output: Send + 'static;
359
360    /// Spawns a new asynchronous task in the background, returning an abortable handle that will cancel the task
361    /// once the handle is dropped.
362    ///
363    /// Note: This function is used if the task is expected to run until the handle is dropped. It is recommended to use
364    /// [`Executor::spawn`] or [`Executor::dispatch`] otherwise.
365    fn spawn_abortable<F>(&self, future: F) -> AbortableJoinHandle<F::Output>
366    where
367        F: Future + Send + 'static,
368        F::Output: Send + 'static,
369    {
370        let handle = self.spawn(future);
371        handle.into()
372    }
373
374    /// Spawns a new asynchronous task in the background without an handle.
375    /// Basically the same as [`Executor::spawn`].
376    fn dispatch<F>(&self, future: F)
377    where
378        F: Future + Send + 'static,
379        F::Output: Send + 'static,
380    {
381        self.spawn(future);
382    }
383
384    /// Spawns a new asynchronous task that accepts messages to the task using [`channels`](futures::channel::mpsc).
385    /// This function returns a handle that allows sending a message, or if there is no reference to the handle at all
386    /// (in other words, all handles are dropped), the task would be aborted.
387    fn spawn_coroutine<T, F, Fut>(&self, f: F) -> CommunicationTask<T>
388    where
389        F: FnMut(Receiver<T>) -> Fut,
390        Fut: Future<Output = ()> + Send + 'static,
391    {
392        Self::spawn_coroutine_with_buffer(self, 1, f)
393    }
394
395    /// Spawns a new asynchronous task with a set channel buffer that accepts messages to the task using [`channels`](futures::channel::mpsc).
396    /// This function returns a handle that allows sending a message, or if there is no reference to the handle at all
397    /// (in other words, all handles are dropped), the task would be aborted.
398    fn spawn_coroutine_with_buffer<T, F, Fut>(
399        &self,
400        buffer: usize,
401        mut f: F,
402    ) -> CommunicationTask<T>
403    where
404        F: FnMut(Receiver<T>) -> Fut,
405        Fut: Future<Output = ()> + Send + 'static,
406    {
407        let (tx, rx) = futures::channel::mpsc::channel(buffer);
408        let fut = f(rx);
409        let _task_handle = self.spawn_abortable(fut);
410        CommunicationTask {
411            _task_handle,
412            _channel_tx: tx,
413        }
414    }
415
416    /// Spawns a new asynchronous task with provided context that accepts messages to the task using [`channels`](futures::channel::mpsc).
417    /// This function returns a handle that allows sending a message, or if there is no reference to the handle at all
418    /// (in other words, all handles are dropped), the task would be aborted.
419    fn spawn_coroutine_with_context<T, F, C, Fut>(&self, context: C, f: F) -> CommunicationTask<T>
420    where
421        F: FnMut(C, Receiver<T>) -> Fut,
422        Fut: Future<Output = ()> + Send + 'static,
423    {
424        Self::spawn_coroutine_with_buffer_and_context(self, context, 1, f)
425    }
426
427    /// Spawns a new asynchronous task with a set channel buffer and provided context that accepts messages to the task using [`channels`](futures::channel::mpsc).
428    /// This function returns a handle that allows sending a message, or if there is no reference to the handle at all
429    /// (in other words, all handles are dropped), the task would be aborted.
430    fn spawn_coroutine_with_buffer_and_context<T, F, C, Fut>(
431        &self,
432        context: C,
433        buffer: usize,
434        mut f: F,
435    ) -> CommunicationTask<T>
436    where
437        F: FnMut(C, Receiver<T>) -> Fut,
438        Fut: Future<Output = ()> + Send + 'static,
439    {
440        let (tx, rx) = futures::channel::mpsc::channel(buffer);
441        let fut = f(context, rx);
442        let _task_handle = self.spawn_abortable(fut);
443        CommunicationTask {
444            _task_handle,
445            _channel_tx: tx,
446        }
447    }
448
449    /// Spawns a new asynchronous task that accepts messages to the task using [`channels`](futures::channel::mpsc).
450    /// This function returns a handle that allows sending a message, or if there is no reference to the handle at all
451    /// (in other words, all handles are dropped), the task would be aborted.
452    fn spawn_unbounded_coroutine<T, F, Fut>(&self, mut f: F) -> UnboundedCommunicationTask<T>
453    where
454        F: FnMut(UnboundedReceiver<T>) -> Fut,
455        Fut: Future<Output = ()> + Send + 'static,
456    {
457        let (tx, rx) = futures::channel::mpsc::unbounded();
458        let fut = f(rx);
459        let _task_handle = self.spawn_abortable(fut);
460        UnboundedCommunicationTask {
461            _task_handle,
462            _channel_tx: tx,
463        }
464    }
465
466    /// Spawns a new asynchronous task with provided context that accepts messages to the task using [`channels`](futures::channel::mpsc).
467    /// This function returns a handle that allows sending a message, or if there is no reference to the handle at all
468    /// (in other words, all handles are dropped), the task would be aborted.
469    fn spawn_unbounded_coroutine_with_context<T, F, C, Fut>(
470        &self,
471        context: C,
472        mut f: F,
473    ) -> UnboundedCommunicationTask<T>
474    where
475        F: FnMut(C, UnboundedReceiver<T>) -> Fut,
476        Fut: Future<Output = ()> + Send + 'static,
477    {
478        let (tx, rx) = futures::channel::mpsc::unbounded();
479        let fut = f(context, rx);
480        let _task_handle = self.spawn_abortable(fut);
481        UnboundedCommunicationTask {
482            _task_handle,
483            _channel_tx: tx,
484        }
485    }
486}
487
488pub trait ExecutorBlocking: Executor {
489    /// Spawn a thread in the background with the ability to concurrently await on the results without
490    /// blocking the executor.
491    ///
492    /// Note that there is no real way to abort the thread, and calling [`JoinHandle::abort`]
493    /// would only abort the underlining tasked that is being polled but not the thread itself.
494    fn spawn_blocking<F, R>(&self, f: F) -> JoinHandle<R>
495    where
496        F: FnOnce() -> R + Send + 'static,
497        R: Send + 'static;
498}
499
#[cfg(test)]
mod tests {
    use crate::{Executor, ExecutorBlocking, InnerJoinHandle, JoinHandle};
    use futures::future::AbortHandle;
    use std::future::Future;

    /// Waits five seconds, then signals on `tx`.
    ///
    /// The test aborts this task long before the delay elapses. Getting past the delay
    /// therefore means abortion failed, and the trailing `unreachable!` makes that loud.
    async fn task(tx: futures::channel::oneshot::Sender<()>) {
        futures_timer::Delay::new(std::time::Duration::from_secs(5)).await;
        let _ = tx.send(());
        unreachable!();
    }

    /// Dropping the only `AbortableJoinHandle` from a custom executor must abort the task.
    /// Aborting drops the task's oneshot sender, which cancels the receiver.
    #[test]
    fn custom_abortable_task() {
        use futures::future::Abortable;

        /// Minimal `Executor` on a `futures` thread pool, producing `CustomHandle` join handles.
        struct FuturesExecutor {
            pool: futures::executor::ThreadPool,
        }

        impl Default for FuturesExecutor {
            fn default() -> Self {
                let pool = futures::executor::ThreadPool::new().unwrap();
                Self { pool }
            }
        }

        impl Executor for FuturesExecutor {
            fn spawn<F>(&self, future: F) -> JoinHandle<F::Output>
            where
                F: Future + Send + 'static,
                F::Output: Send + 'static,
            {
                let (abort_handle, registration) = AbortHandle::new_pair();
                let abortable = Abortable::new(future, registration);
                let (result_tx, result_rx) = futures::channel::oneshot::channel();

                // Forward the outcome (value or `Aborted`) to the join handle; a dropped
                // receiver is ignored.
                self.pool.spawn_ok(async move {
                    let outcome = abortable.await;
                    let _ = result_tx.send(outcome);
                });

                JoinHandle {
                    inner: InnerJoinHandle::CustomHandle {
                        inner: Some(result_rx),
                        handle: abort_handle,
                    },
                }
            }
        }

        impl ExecutorBlocking for FuturesExecutor {
            fn spawn_blocking<F, R>(&self, _: F) -> JoinHandle<R>
            where
                F: FnOnce() -> R + Send + 'static,
                R: Send + 'static,
            {
                unimplemented!()
            }
        }

        futures::executor::block_on(async move {
            let executor = FuturesExecutor::default();
            let (done_tx, done_rx) = futures::channel::oneshot::channel::<()>();

            drop(executor.spawn_abortable(task(done_tx)));

            // The aborted task never sends, so its sender is dropped and the receiver is cancelled.
            assert!(done_rx.await.is_err());
        });
    }
}