1pub mod arc;
2pub mod global;
3pub mod rt;
4pub mod task;
5pub mod tracker;
6
7#[cfg(feature = "either")]
8mod either;
9
10use std::fmt::{Debug, Formatter};
11
12use futures::channel::mpsc::{Receiver, UnboundedReceiver};
13use futures::future::{AbortHandle, Aborted};
14use futures::SinkExt;
15use std::future::Future;
16use std::pin::Pin;
17use std::sync::Arc;
18use std::task::{Context, Poll};
19
// Refuse to build without an async runtime: at least one of the `tokio` or
// `threadpool` features must be enabled. wasm32 targets are exempt — presumably
// they use the 'wasm-bindgen-futures' runtime named in the message (the cfg below
// only checks the target architecture, not a feature).
#[cfg(all(
    not(feature = "threadpool"),
    not(feature = "tokio"),
    not(target_arch = "wasm32")
))]
compile_error!(
    "At least one runtime (i.e 'tokio', 'threadpool', 'wasm-bindgen-futures') must be enabled"
);
28
29pub struct JoinHandle<T> {
40 inner: InnerJoinHandle<T>,
41}
42
43impl<T> Debug for JoinHandle<T> {
44 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
45 f.debug_struct("JoinHandle").finish()
46 }
47}
48
/// Runtime-specific state behind a [`JoinHandle`].
enum InnerJoinHandle<T> {
    /// A tokio task handle; only compiled with the `tokio` feature on non-wasm32 targets.
    #[cfg(all(feature = "tokio", not(target_arch = "wasm32")))]
    TokioHandle(::tokio::task::JoinHandle<T>),
    /// A task driven by some other executor: its result arrives over a oneshot channel
    /// and it is cancelled through `handle`.
    // NOTE(review): `dead_code` is presumably allowed because some feature/target
    // combinations never construct this variant — confirm and state the reason here.
    #[allow(dead_code)]
    CustomHandle {
        /// Receiver for the task's result (an `Err(Aborted)` presumably means the task
        /// was aborted); set to `None` once `JoinHandle::poll` has taken the result.
        inner: Option<futures::channel::oneshot::Receiver<Result<T, Aborted>>>,
        /// Used by `JoinHandle::abort` to cancel the task.
        handle: AbortHandle,
    },
    /// No task attached. This is also the `Default` value, so `std::mem::take` leaves it
    /// behind in the handle it takes from.
    Empty,
}
59
60impl<T> Default for InnerJoinHandle<T> {
61 fn default() -> Self {
62 Self::Empty
63 }
64}
65
66impl<T> JoinHandle<T> {
67 pub fn empty() -> Self {
69 JoinHandle {
70 inner: InnerJoinHandle::Empty,
71 }
72 }
73}
74
75impl<T> JoinHandle<T> {
76 pub fn abort(&self) {
78 match self.inner {
79 #[cfg(all(feature = "tokio", not(target_arch = "wasm32")))]
80 InnerJoinHandle::TokioHandle(ref handle) => handle.abort(),
81 InnerJoinHandle::CustomHandle { ref handle, .. } => handle.abort(),
82 InnerJoinHandle::Empty => {}
83 }
84 }
85
86 pub fn is_finished(&self) -> bool {
91 match self.inner {
92 #[cfg(all(feature = "tokio", not(target_arch = "wasm32")))]
93 InnerJoinHandle::TokioHandle(ref handle) => handle.is_finished(),
94 InnerJoinHandle::CustomHandle {
95 ref handle,
96 ref inner,
97 } => handle.is_aborted() || inner.is_none(),
98 InnerJoinHandle::Empty => true,
99 }
100 }
101
102 pub unsafe fn replace(&mut self, mut handle: JoinHandle<T>) {
109 self.inner = std::mem::take(&mut handle.inner);
110 }
111
112 pub unsafe fn replace_in_place(&mut self, handle: &mut JoinHandle<T>) {
119 self.inner = std::mem::take(&mut handle.inner);
120 }
121}
122
123impl<T> Future for JoinHandle<T> {
124 type Output = std::io::Result<T>;
125 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
126 let inner = &mut self.inner;
127 match inner {
128 #[cfg(all(feature = "tokio", not(target_arch = "wasm32")))]
129 InnerJoinHandle::TokioHandle(handle) => {
130 let fut = futures::ready!(Pin::new(handle).poll(cx));
131
132 match fut {
133 Ok(val) => Poll::Ready(Ok(val)),
134 Err(e) => {
135 let e = std::io::Error::other(e);
136 Poll::Ready(Err(e))
137 }
138 }
139 }
140 InnerJoinHandle::CustomHandle { inner, .. } => {
141 let Some(this) = inner.as_mut() else {
142 unreachable!("cannot poll a completed future");
143 };
144
145 let fut = futures::ready!(Pin::new(this).poll(cx));
146 inner.take();
147
148 match fut {
149 Ok(Ok(val)) => Poll::Ready(Ok(val)),
150 Ok(Err(e)) => {
151 let e = std::io::Error::other(e);
152 Poll::Ready(Err(e))
153 }
154 Err(e) => {
155 let e = std::io::Error::other(e);
156 Poll::Ready(Err(e))
157 }
158 }
159 }
160 InnerJoinHandle::Empty => {
161 Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::Other)))
162 }
163 }
164 }
165}
166
167#[derive(Clone)]
170pub struct AbortableJoinHandle<T> {
171 handle: Arc<InnerHandle<T>>,
172}
173
174impl<T> Debug for AbortableJoinHandle<T> {
175 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
176 f.debug_struct("AbortableJoinHandle").finish()
177 }
178}
179
180impl<T> From<JoinHandle<T>> for AbortableJoinHandle<T> {
181 fn from(handle: JoinHandle<T>) -> Self {
182 AbortableJoinHandle {
183 handle: Arc::new(InnerHandle {
184 inner: parking_lot::Mutex::new(handle),
185 }),
186 }
187 }
188}
189
190impl<T> AbortableJoinHandle<T> {
191 pub fn empty() -> Self {
193 Self {
194 handle: Arc::new(InnerHandle {
195 inner: parking_lot::Mutex::new(JoinHandle::empty()),
196 }),
197 }
198 }
199}
200
201impl<T> AbortableJoinHandle<T> {
202 pub fn abort(&self) {
204 self.handle.inner.lock().abort();
205 }
206
207 pub fn is_finished(&self) -> bool {
209 self.handle.inner.lock().is_finished()
210 }
211
212 pub unsafe fn replace(&mut self, inner: AbortableJoinHandle<T>) {
219 let current_handle = &mut *self.handle.inner.lock();
220 let inner_handle = &mut *inner.handle.inner.lock();
221 current_handle.replace_in_place(inner_handle);
222 }
223}
224
225struct InnerHandle<T> {
226 pub inner: parking_lot::Mutex<JoinHandle<T>>,
227}
228
229impl<T> Drop for InnerHandle<T> {
230 fn drop(&mut self) {
231 self.inner.lock().abort();
232 }
233}
234
235impl<T> Future for AbortableJoinHandle<T> {
236 type Output = std::io::Result<T>;
237 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
238 let inner = &mut *self.handle.inner.lock();
239 Pin::new(inner).poll(cx).map_err(std::io::Error::other)
240 }
241}
242
/// Handle to a coroutine spawned by [`Executor::spawn_coroutine`] (or its
/// `_with_context` variant) that receives messages over a bounded channel.
///
/// Clones share one task handle, and the task is aborted once the last clone is
/// dropped. Each clone holds its own clone of the sender.
pub struct CommunicationTask<T> {
    _task_handle: AbortableJoinHandle<()>,
    _channel_tx: futures::channel::mpsc::Sender<T>,
}

// SAFETY: NOTE(review): both fields look `Send`/`Sync` on their own when `T: Send`:
// `AbortableJoinHandle<()>` is an `Arc` around a `parking_lot::Mutex`, and
// `futures::channel::mpsc::Sender<T>` is expected to be `Send + Sync` for `T: Send`.
// These manual impls are therefore presumably redundant — confirm before relying on them.
unsafe impl<T: Send> Send for CommunicationTask<T> {}
unsafe impl<T: Send> Sync for CommunicationTask<T> {}
251
252impl<T> Clone for CommunicationTask<T> {
253 fn clone(&self) -> Self {
254 CommunicationTask {
255 _task_handle: self._task_handle.clone(),
256 _channel_tx: self._channel_tx.clone(),
257 }
258 }
259}
260
261impl<T> Debug for CommunicationTask<T> {
262 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
263 f.debug_struct("CommunicationTask").finish()
264 }
265}
266
267impl<T> CommunicationTask<T>
268where
269 T: Send + Sync + 'static,
270{
271 pub async fn send(&mut self, data: T) -> std::io::Result<()> {
273 self._channel_tx
274 .send(data)
275 .await
276 .map_err(std::io::Error::other)
277 }
278
279 pub fn try_send(&self, data: T) -> std::io::Result<()> {
281 self._channel_tx
282 .clone()
283 .try_send(data)
284 .map_err(std::io::Error::other)
285 }
286
287 pub fn abort(mut self) {
289 self._channel_tx.close_channel();
290 self._task_handle.abort();
291 }
292
293 pub fn is_active(&self) -> bool {
295 !self._task_handle.is_finished() && !self._channel_tx.is_closed()
296 }
297}
298
/// Handle to a coroutine spawned by [`Executor::spawn_unbounded_coroutine`] (or its
/// `_with_context` variant) that receives messages over an unbounded channel.
///
/// Clones share one task handle, and the task is aborted once the last clone is
/// dropped. Each clone holds its own clone of the sender.
pub struct UnboundedCommunicationTask<T> {
    _task_handle: AbortableJoinHandle<()>,
    _channel_tx: futures::channel::mpsc::UnboundedSender<T>,
}

// SAFETY: NOTE(review): both fields look `Send`/`Sync` on their own when `T: Send`:
// `AbortableJoinHandle<()>` is an `Arc` around a `parking_lot::Mutex`, and
// `futures::channel::mpsc::UnboundedSender<T>` is expected to be `Send + Sync` for
// `T: Send`. These manual impls are therefore presumably redundant — confirm before
// relying on them.
unsafe impl<T: Send> Send for UnboundedCommunicationTask<T> {}
unsafe impl<T: Send> Sync for UnboundedCommunicationTask<T> {}
307
308impl<T> Clone for UnboundedCommunicationTask<T> {
309 fn clone(&self) -> Self {
310 UnboundedCommunicationTask {
311 _task_handle: self._task_handle.clone(),
312 _channel_tx: self._channel_tx.clone(),
313 }
314 }
315}
316
317impl<T> Debug for UnboundedCommunicationTask<T> {
318 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
319 f.debug_struct("UnboundedCommunicationTask").finish()
320 }
321}
322
323impl<T> UnboundedCommunicationTask<T>
324where
325 T: Send + Sync + 'static,
326{
327 pub fn send(&mut self, data: T) -> std::io::Result<()> {
329 self._channel_tx
330 .unbounded_send(data)
331 .map_err(std::io::Error::other)
332 }
333
334 pub fn abort(self) {
336 self._channel_tx.close_channel();
337 self._task_handle.abort();
338 }
339
340 pub fn is_active(&self) -> bool {
342 !self._task_handle.is_finished() && !self._channel_tx.is_closed()
343 }
344}
345
346pub trait Executor {
347 fn spawn<F>(&self, future: F) -> JoinHandle<F::Output>
349 where
350 F: Future + Send + 'static,
351 F::Output: Send + 'static;
352
353 fn spawn_abortable<F>(&self, future: F) -> AbortableJoinHandle<F::Output>
359 where
360 F: Future + Send + 'static,
361 F::Output: Send + 'static,
362 {
363 let handle = self.spawn(future);
364 handle.into()
365 }
366
367 fn dispatch<F>(&self, future: F)
370 where
371 F: Future + Send + 'static,
372 F::Output: Send + 'static,
373 {
374 self.spawn(future);
375 }
376
377 fn spawn_coroutine<T, F, Fut>(&self, mut f: F) -> CommunicationTask<T>
381 where
382 F: FnMut(Receiver<T>) -> Fut,
383 Fut: Future<Output = ()> + Send + 'static,
384 {
385 let (tx, rx) = futures::channel::mpsc::channel(1);
386 let fut = f(rx);
387 let _task_handle = self.spawn_abortable(fut);
388 CommunicationTask {
389 _task_handle,
390 _channel_tx: tx,
391 }
392 }
393
394 fn spawn_coroutine_with_context<T, F, C, Fut>(
398 &self,
399 context: C,
400 mut f: F,
401 ) -> CommunicationTask<T>
402 where
403 F: FnMut(C, Receiver<T>) -> Fut,
404 Fut: Future<Output = ()> + Send + 'static,
405 {
406 let (tx, rx) = futures::channel::mpsc::channel(1);
407 let fut = f(context, rx);
408 let _task_handle = self.spawn_abortable(fut);
409 CommunicationTask {
410 _task_handle,
411 _channel_tx: tx,
412 }
413 }
414
415 fn spawn_unbounded_coroutine<T, F, Fut>(&self, mut f: F) -> UnboundedCommunicationTask<T>
419 where
420 F: FnMut(UnboundedReceiver<T>) -> Fut,
421 Fut: Future<Output = ()> + Send + 'static,
422 {
423 let (tx, rx) = futures::channel::mpsc::unbounded();
424 let fut = f(rx);
425 let _task_handle = self.spawn_abortable(fut);
426 UnboundedCommunicationTask {
427 _task_handle,
428 _channel_tx: tx,
429 }
430 }
431
432 fn spawn_unbounded_coroutine_with_context<T, F, C, Fut>(
436 &self,
437 context: C,
438 mut f: F,
439 ) -> UnboundedCommunicationTask<T>
440 where
441 F: FnMut(C, UnboundedReceiver<T>) -> Fut,
442 Fut: Future<Output = ()> + Send + 'static,
443 {
444 let (tx, rx) = futures::channel::mpsc::unbounded();
445 let fut = f(context, rx);
446 let _task_handle = self.spawn_abortable(fut);
447 UnboundedCommunicationTask {
448 _task_handle,
449 _channel_tx: tx,
450 }
451 }
452}
453
#[cfg(test)]
mod tests {
    use crate::{Executor, InnerJoinHandle, JoinHandle};
    use futures::future::{AbortHandle, Abortable};
    use std::future::Future;

    /// Test executor that runs futures on a `futures` thread pool and hands out
    /// `CustomHandle`-backed join handles, exercising the non-tokio code path.
    struct FuturesExecutor {
        pool: futures::executor::ThreadPool,
    }

    impl Default for FuturesExecutor {
        fn default() -> Self {
            let pool = futures::executor::ThreadPool::new().unwrap();
            Self { pool }
        }
    }

    impl Executor for FuturesExecutor {
        fn spawn<F>(&self, future: F) -> JoinHandle<F::Output>
        where
            F: Future + Send + 'static,
            F::Output: Send + 'static,
        {
            let (handle, registration) = AbortHandle::new_pair();
            let (result_tx, result_rx) = futures::channel::oneshot::channel();
            let abortable = Abortable::new(future, registration);
            self.pool.spawn_ok(async move {
                // The join handle may already be gone; the result is then discarded.
                let _ = result_tx.send(abortable.await);
            });
            JoinHandle {
                inner: InnerJoinHandle::CustomHandle {
                    inner: Some(result_rx),
                    handle,
                },
            }
        }
    }

    /// Waits five seconds and then signals through `tx`. The test aborts it long before
    /// that, so reaching the end means the abort did not take effect.
    async fn task(tx: futures::channel::oneshot::Sender<()>) {
        futures_timer::Delay::new(std::time::Duration::from_secs(5)).await;
        let _ = tx.send(());
        unreachable!();
    }

    #[test]
    fn custom_abortable_task() {
        futures::executor::block_on(async {
            let executor = FuturesExecutor::default();
            let (tx, rx) = futures::channel::oneshot::channel::<()>();
            // Dropping the only `AbortableJoinHandle` aborts the task. That drops `tx`
            // without sending, so `rx` must resolve to an error.
            drop(executor.spawn_abortable(task(tx)));
            assert!(rx.await.is_err());
        });
    }
}