1pub mod arc;
2pub mod global;
3pub mod rt;
4pub mod task;
5pub mod tracker;
6
7#[cfg(feature = "either")]
8pub mod either;
9pub mod rc;
10
11use std::fmt::{Debug, Formatter};
12
13use futures::channel::mpsc::{Receiver, UnboundedReceiver};
14use futures::future::{AbortHandle, Aborted};
15use futures::SinkExt;
16use std::future::Future;
17use std::pin::Pin;
18use std::sync::Arc;
19use std::task::{Context, Poll};
20
// Fail the build unless a task runtime is available: either the `threadpool` or `tokio`
// feature must be enabled, or the target must be wasm32.
// NOTE(review): the message names 'wasm-bindgen-futures', but the cfg below checks
// `target_arch = "wasm32"` rather than a feature of that name — confirm the wording.
#[cfg(all(
    not(feature = "threadpool"),
    not(feature = "tokio"),
    not(target_arch = "wasm32")
))]
compile_error!(
    "At least one runtime (i.e 'tokio', 'threadpool', 'wasm-bindgen-futures') must be enabled"
);
29
30pub struct JoinHandle<T> {
41 inner: InnerJoinHandle<T>,
42}
43
44impl<T> Debug for JoinHandle<T> {
45 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
46 f.debug_struct("JoinHandle").finish()
47 }
48}
49
50enum InnerJoinHandle<T> {
51 #[cfg(all(feature = "tokio", not(target_arch = "wasm32")))]
52 TokioHandle(::tokio::task::JoinHandle<T>),
53 #[allow(dead_code)]
54 CustomHandle {
55 inner: Option<futures::channel::oneshot::Receiver<Result<T, Aborted>>>,
56 handle: AbortHandle,
57 },
58 Empty,
59}
60
61impl<T> Default for InnerJoinHandle<T> {
62 fn default() -> Self {
63 Self::Empty
64 }
65}
66
67impl<T> JoinHandle<T> {
68 pub fn empty() -> Self {
70 JoinHandle {
71 inner: InnerJoinHandle::Empty,
72 }
73 }
74}
75
76impl<T> JoinHandle<T> {
77 pub fn abort(&self) {
79 match self.inner {
80 #[cfg(all(feature = "tokio", not(target_arch = "wasm32")))]
81 InnerJoinHandle::TokioHandle(ref handle) => handle.abort(),
82 InnerJoinHandle::CustomHandle { ref handle, .. } => handle.abort(),
83 InnerJoinHandle::Empty => {}
84 }
85 }
86
87 pub fn is_finished(&self) -> bool {
92 match self.inner {
93 #[cfg(all(feature = "tokio", not(target_arch = "wasm32")))]
94 InnerJoinHandle::TokioHandle(ref handle) => handle.is_finished(),
95 InnerJoinHandle::CustomHandle {
96 ref handle,
97 ref inner,
98 } => handle.is_aborted() || inner.is_none(),
99 InnerJoinHandle::Empty => true,
100 }
101 }
102
103 pub unsafe fn replace(&mut self, mut handle: JoinHandle<T>) {
110 self.inner = std::mem::take(&mut handle.inner);
111 }
112
113 pub unsafe fn replace_in_place(&mut self, handle: &mut JoinHandle<T>) {
120 self.inner = std::mem::take(&mut handle.inner);
121 }
122}
123
124impl<T> Future for JoinHandle<T> {
125 type Output = std::io::Result<T>;
126 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
127 let inner = &mut self.inner;
128 match inner {
129 #[cfg(all(feature = "tokio", not(target_arch = "wasm32")))]
130 InnerJoinHandle::TokioHandle(handle) => {
131 let fut = futures::ready!(Pin::new(handle).poll(cx));
132
133 match fut {
134 Ok(val) => Poll::Ready(Ok(val)),
135 Err(e) => {
136 let e = std::io::Error::other(e);
137 Poll::Ready(Err(e))
138 }
139 }
140 }
141 InnerJoinHandle::CustomHandle { inner, .. } => {
142 let Some(this) = inner.as_mut() else {
143 unreachable!("cannot poll a completed future");
144 };
145
146 let fut = futures::ready!(Pin::new(this).poll(cx));
147 inner.take();
148
149 match fut {
150 Ok(Ok(val)) => Poll::Ready(Ok(val)),
151 Ok(Err(e)) => {
152 let e = std::io::Error::other(e);
153 Poll::Ready(Err(e))
154 }
155 Err(e) => {
156 let e = std::io::Error::other(e);
157 Poll::Ready(Err(e))
158 }
159 }
160 }
161 InnerJoinHandle::Empty => {
162 Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::Other)))
163 }
164 }
165 }
166}
167
168#[derive(Clone)]
171pub struct AbortableJoinHandle<T> {
172 handle: Arc<InnerHandle<T>>,
173}
174
175impl<T> Debug for AbortableJoinHandle<T> {
176 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
177 f.debug_struct("AbortableJoinHandle").finish()
178 }
179}
180
181impl<T> From<JoinHandle<T>> for AbortableJoinHandle<T> {
182 fn from(handle: JoinHandle<T>) -> Self {
183 AbortableJoinHandle {
184 handle: Arc::new(InnerHandle {
185 inner: parking_lot::Mutex::new(handle),
186 }),
187 }
188 }
189}
190
191impl<T> AbortableJoinHandle<T> {
192 pub fn empty() -> Self {
194 Self {
195 handle: Arc::new(InnerHandle {
196 inner: parking_lot::Mutex::new(JoinHandle::empty()),
197 }),
198 }
199 }
200}
201
202impl<T> AbortableJoinHandle<T> {
203 pub fn abort(&self) {
205 self.handle.inner.lock().abort();
206 }
207
208 pub fn is_finished(&self) -> bool {
210 self.handle.inner.lock().is_finished()
211 }
212
213 pub unsafe fn replace(&mut self, inner: AbortableJoinHandle<T>) {
220 let current_handle = &mut *self.handle.inner.lock();
221 let inner_handle = &mut *inner.handle.inner.lock();
222 current_handle.replace_in_place(inner_handle);
223 }
224}
225
226struct InnerHandle<T> {
227 pub inner: parking_lot::Mutex<JoinHandle<T>>,
228}
229
230impl<T> Drop for InnerHandle<T> {
231 fn drop(&mut self) {
232 self.inner.lock().abort();
233 }
234}
235
236impl<T> Future for AbortableJoinHandle<T> {
237 type Output = std::io::Result<T>;
238 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
239 let inner = &mut *self.handle.inner.lock();
240 Pin::new(inner).poll(cx).map_err(std::io::Error::other)
241 }
242}
243
244pub struct CommunicationTask<T> {
246 _task_handle: AbortableJoinHandle<()>,
247 _channel_tx: futures::channel::mpsc::Sender<T>,
248}
249
250unsafe impl<T: Send> Send for CommunicationTask<T> {}
251unsafe impl<T: Send> Sync for CommunicationTask<T> {}
252
253impl<T> Clone for CommunicationTask<T> {
254 fn clone(&self) -> Self {
255 CommunicationTask {
256 _task_handle: self._task_handle.clone(),
257 _channel_tx: self._channel_tx.clone(),
258 }
259 }
260}
261
262impl<T> Debug for CommunicationTask<T> {
263 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
264 f.debug_struct("CommunicationTask").finish()
265 }
266}
267
268impl<T> CommunicationTask<T>
269where
270 T: 'static,
271{
272 pub async fn send(&mut self, data: T) -> std::io::Result<()> {
274 self._channel_tx
275 .send(data)
276 .await
277 .map_err(std::io::Error::other)
278 }
279
280 pub fn try_send(&self, data: T) -> std::io::Result<()>
282 where
283 T: Send + Sync,
284 {
285 self._channel_tx
286 .clone()
287 .try_send(data)
288 .map_err(std::io::Error::other)
289 }
290
291 pub fn abort(mut self) {
293 self._channel_tx.close_channel();
294 self._task_handle.abort();
295 }
296
297 pub fn is_active(&self) -> bool {
299 !self._task_handle.is_finished() && !self._channel_tx.is_closed()
300 }
301}
302
303pub struct UnboundedCommunicationTask<T> {
305 _task_handle: AbortableJoinHandle<()>,
306 _channel_tx: futures::channel::mpsc::UnboundedSender<T>,
307}
308
309unsafe impl<T: Send> Send for UnboundedCommunicationTask<T> {}
310unsafe impl<T: Send> Sync for UnboundedCommunicationTask<T> {}
311
312impl<T> Clone for UnboundedCommunicationTask<T> {
313 fn clone(&self) -> Self {
314 UnboundedCommunicationTask {
315 _task_handle: self._task_handle.clone(),
316 _channel_tx: self._channel_tx.clone(),
317 }
318 }
319}
320
321impl<T> Debug for UnboundedCommunicationTask<T> {
322 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
323 f.debug_struct("UnboundedCommunicationTask").finish()
324 }
325}
326
327impl<T> UnboundedCommunicationTask<T>
328where
329 T: 'static,
330{
331 pub fn send(&mut self, data: T) -> std::io::Result<()>
333 where
334 T: Send + Sync,
335 {
336 self._channel_tx
337 .unbounded_send(data)
338 .map_err(std::io::Error::other)
339 }
340
341 pub fn abort(self) {
343 self._channel_tx.close_channel();
344 self._task_handle.abort();
345 }
346
347 pub fn is_active(&self) -> bool {
349 !self._task_handle.is_finished() && !self._channel_tx.is_closed()
350 }
351}
352
353pub trait Executor {
354 fn spawn<F>(&self, future: F) -> JoinHandle<F::Output>
356 where
357 F: Future + Send + 'static,
358 F::Output: Send + 'static;
359
360 fn spawn_abortable<F>(&self, future: F) -> AbortableJoinHandle<F::Output>
366 where
367 F: Future + Send + 'static,
368 F::Output: Send + 'static,
369 {
370 let handle = self.spawn(future);
371 handle.into()
372 }
373
374 fn dispatch<F>(&self, future: F)
377 where
378 F: Future + Send + 'static,
379 F::Output: Send + 'static,
380 {
381 self.spawn(future);
382 }
383
384 fn spawn_coroutine<T, F, Fut>(&self, f: F) -> CommunicationTask<T>
388 where
389 F: FnMut(Receiver<T>) -> Fut,
390 Fut: Future<Output = ()> + Send + 'static,
391 {
392 Self::spawn_coroutine_with_buffer(self, 1, f)
393 }
394
395 fn spawn_coroutine_with_buffer<T, F, Fut>(
399 &self,
400 buffer: usize,
401 mut f: F,
402 ) -> CommunicationTask<T>
403 where
404 F: FnMut(Receiver<T>) -> Fut,
405 Fut: Future<Output = ()> + Send + 'static,
406 {
407 let (tx, rx) = futures::channel::mpsc::channel(buffer);
408 let fut = f(rx);
409 let _task_handle = self.spawn_abortable(fut);
410 CommunicationTask {
411 _task_handle,
412 _channel_tx: tx,
413 }
414 }
415
416 fn spawn_coroutine_with_context<T, F, C, Fut>(&self, context: C, f: F) -> CommunicationTask<T>
420 where
421 F: FnMut(C, Receiver<T>) -> Fut,
422 Fut: Future<Output = ()> + Send + 'static,
423 {
424 Self::spawn_coroutine_with_buffer_and_context(self, context, 1, f)
425 }
426
427 fn spawn_coroutine_with_buffer_and_context<T, F, C, Fut>(
431 &self,
432 context: C,
433 buffer: usize,
434 mut f: F,
435 ) -> CommunicationTask<T>
436 where
437 F: FnMut(C, Receiver<T>) -> Fut,
438 Fut: Future<Output = ()> + Send + 'static,
439 {
440 let (tx, rx) = futures::channel::mpsc::channel(buffer);
441 let fut = f(context, rx);
442 let _task_handle = self.spawn_abortable(fut);
443 CommunicationTask {
444 _task_handle,
445 _channel_tx: tx,
446 }
447 }
448
449 fn spawn_unbounded_coroutine<T, F, Fut>(&self, mut f: F) -> UnboundedCommunicationTask<T>
453 where
454 F: FnMut(UnboundedReceiver<T>) -> Fut,
455 Fut: Future<Output = ()> + Send + 'static,
456 {
457 let (tx, rx) = futures::channel::mpsc::unbounded();
458 let fut = f(rx);
459 let _task_handle = self.spawn_abortable(fut);
460 UnboundedCommunicationTask {
461 _task_handle,
462 _channel_tx: tx,
463 }
464 }
465
466 fn spawn_unbounded_coroutine_with_context<T, F, C, Fut>(
470 &self,
471 context: C,
472 mut f: F,
473 ) -> UnboundedCommunicationTask<T>
474 where
475 F: FnMut(C, UnboundedReceiver<T>) -> Fut,
476 Fut: Future<Output = ()> + Send + 'static,
477 {
478 let (tx, rx) = futures::channel::mpsc::unbounded();
479 let fut = f(context, rx);
480 let _task_handle = self.spawn_abortable(fut);
481 UnboundedCommunicationTask {
482 _task_handle,
483 _channel_tx: tx,
484 }
485 }
486}
487
488pub trait ExecutorBlocking: Executor {
489 fn spawn_blocking<F, R>(&self, f: F) -> JoinHandle<R>
495 where
496 F: FnOnce() -> R + Send + 'static,
497 R: Send + 'static;
498}
499
#[cfg(test)]
mod tests {
    use crate::{Executor, ExecutorBlocking, InnerJoinHandle, JoinHandle};
    use futures::future::AbortHandle;
    use std::future::Future;

    /// Test task: waits 5 seconds, then signals on `tx`.
    ///
    /// It is expected to be aborted long before the delay elapses. If it ever runs to the
    /// end, `tx` fires (making the test's assertion fail) and the task panics.
    async fn task(tx: futures::channel::oneshot::Sender<()>) {
        futures_timer::Delay::new(std::time::Duration::from_secs(5)).await;
        let _ = tx.send(());
        unreachable!();
    }

    /// Dropping the only `AbortableJoinHandle` of a task spawned by a custom executor
    /// must abort that task.
    #[test]
    fn custom_abortable_task() {
        use futures::future::Abortable;
        // Minimal executor backed by futures' `ThreadPool`; it produces `CustomHandle`s.
        struct FuturesExecutor {
            pool: futures::executor::ThreadPool,
        }

        impl Default for FuturesExecutor {
            fn default() -> Self {
                Self {
                    pool: futures::executor::ThreadPool::new().unwrap(),
                }
            }
        }

        impl Executor for FuturesExecutor {
            fn spawn<F>(&self, future: F) -> JoinHandle<F::Output>
            where
                F: Future + Send + 'static,
                F::Output: Send + 'static,
            {
                // Wrap the future so the returned handle can abort it.
                let (abort_handle, abort_registration) = AbortHandle::new_pair();
                let future = Abortable::new(future, abort_registration);
                // The result (`Ok(output)` or `Err(Aborted)`) is forwarded over a oneshot
                // channel, which is what `CustomHandle` polls.
                let (tx, rx) = futures::channel::oneshot::channel();
                let fut = async {
                    let val = future.await;
                    let _ = tx.send(val);
                };

                self.pool.spawn_ok(fut);
                let inner = InnerJoinHandle::CustomHandle {
                    inner: Some(rx),
                    handle: abort_handle,
                };

                JoinHandle { inner }
            }
        }

        // Not exercised by this test.
        impl ExecutorBlocking for FuturesExecutor {
            fn spawn_blocking<F, R>(&self, _: F) -> JoinHandle<R>
            where
                F: FnOnce() -> R + Send + 'static,
                R: Send + 'static,
            {
                unimplemented!()
            }
        }

        futures::executor::block_on(async move {
            let executor = FuturesExecutor::default();

            let (tx, rx) = futures::channel::oneshot::channel::<()>();
            let handle = executor.spawn_abortable(task(tx));
            // Dropping the last handle aborts the task. The aborted future, and the `tx` it
            // owns, is then dropped, so `rx` resolves to an error instead of `Ok(())`.
            drop(handle);
            let result = rx.await;
            assert!(result.is_err());
        });
    }
}
571}