1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
// Copyright 2018-2023 the Deno authors. All rights reserved. MIT license.
use core::pin::Pin;
use core::task::Context;
use core::task::Poll;
use std::future::Future;
use std::marker::PhantomData;
use tokio::runtime::Handle;
use tokio::runtime::RuntimeFlavor;

/// Handle to a task started with [`spawn`] or [`spawn_blocking`]; the counterpart of
/// [`tokio::task::JoinHandle`].
///
/// Awaiting it yields the task's output, or a [`tokio::task::JoinError`] on failure.
#[repr(transparent)]
pub struct JoinHandle<R> {
    /// Underlying tokio handle. The task output is wrapped in [`MaskResultAsSend`] so it
    /// can pass through tokio's `Send`-bounded spawn APIs.
    handle: tokio::task::JoinHandle<MaskResultAsSend<R>>,
    /// Makes `JoinHandle<R>` inherit the auto traits of `R` (for example, it is `Send`
    /// only when `R` is), since `MaskResultAsSend<R>` is unconditionally `Send`.
    _r: PhantomData<R>,
}

impl<R> JoinHandle<R> {
    /// Equivalent to [`tokio::task::JoinHandle::abort`].
    pub fn abort(&self) {
        self.handle.abort()
    }
}

impl<R> Future for JoinHandle<R> {
    type Output = Result<R, tokio::task::JoinError>;

    fn poll(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Self::Output> {
        // SAFETY: We are sure that handle is valid here
        unsafe {
            let me: &mut Self = Pin::into_inner_unchecked(self);
            let handle = Pin::new_unchecked(&mut me.handle);
            match handle.poll(cx) {
                Poll::Pending => Poll::Pending,
                Poll::Ready(Ok(r)) => Poll::Ready(Ok(r.into_inner())),
                Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
            }
        }
    }
}

/// Equivalent to [`tokio::task::spawn`], but does not require the future to be [`Send`]. Must only be
/// used on a [`RuntimeFlavor::CurrentThread`] executor, though this is only checked when running with
/// debug assertions.
#[inline(always)]
pub fn spawn<F: Future<Output = R> + 'static, R: 'static>(f: F) -> JoinHandle<R> {
    debug_assert!(Handle::current().runtime_flavor() == RuntimeFlavor::CurrentThread);
    // SAFETY: we know this is a current-thread executor
    let future = unsafe { MaskFutureAsSend::new(f) };
    JoinHandle {
        handle: tokio::task::spawn(future),
        _r: Default::default(),
    }
}

/// Equivalent to [`tokio::task::spawn_blocking`]. Currently a thin wrapper around the tokio API, but this
/// may change in the future.
#[inline(always)]
pub fn spawn_blocking<F: (FnOnce() -> R) + Send + 'static, R: Send + 'static>(
    f: F,
) -> JoinHandle<R> {
    let handle = tokio::task::spawn_blocking(|| MaskResultAsSend { result: f() });
    JoinHandle {
        handle,
        _r: Default::default(),
    }
}

/// Transparent wrapper that carries a task's output through tokio APIs requiring `Send`.
/// It is marked `Send` regardless of `R`; see the `unsafe impl` below it.
#[doc(hidden)]
#[repr(transparent)]
pub struct MaskResultAsSend<R> {
    /// The wrapped task output.
    result: R,
}

// SAFETY: `Send` is faked for arbitrary `R`. `spawn` only relies on this under its
// documented requirement that the caller runs a current-thread executor, and
// `spawn_blocking` only wraps outputs whose type is already `Send`.
unsafe impl<R> Send for MaskResultAsSend<R> {}

impl<R> MaskResultAsSend<R> {
    /// Consumes the wrapper and hands back the value it was carrying.
    #[inline(always)]
    pub fn into_inner(self) -> R {
        let Self { result } = self;
        result
    }
}

/// Transparent wrapper that presents a (possibly non-`Send`) future as `Send`, so it can
/// be given to [`tokio::task::spawn`]. Instances are built with the `unsafe` constructor
/// `MaskFutureAsSend::new`.
#[repr(transparent)]
pub struct MaskFutureAsSend<F> {
    /// The wrapped future; it has no `Send` bound.
    future: F,
}

impl<F> MaskFutureAsSend<F> {
    /// Mark a non-`Send` future as `Send`. This is a trick to be able to use
    /// `tokio::spawn()` (which requires `Send` futures) in a current thread
    /// runtime.
    ///
    /// # Safety
    ///
    /// You must ensure that the future is actually used on the same
    /// thread, ie. always use current thread runtime flavor from Tokio.
    #[inline(always)]
    pub unsafe fn new(future: F) -> Self {
        Self { future }
    }
}

// SAFETY: this is deliberately dishonest - the wrapper is NOT genuinely `Send`. The impl
// only exists so it satisfies the `Send` bound of tokio's `spawn()`; soundness rests on
// the caller contract of the `unsafe` constructor `MaskFutureAsSend::new`.
unsafe impl<Fut> Send for MaskFutureAsSend<Fut> {}

impl<F: Future> Future for MaskFutureAsSend<F> {
    type Output = MaskResultAsSend<F::Output>;

    /// Polls the wrapped future in place and wraps its output in [`MaskResultAsSend`].
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<MaskResultAsSend<F::Output>> {
        // SAFETY: `future` is treated as structurally pinned. This relies on the wrapper
        // never moving `future` out of a pinned `MaskFutureAsSend` (no method here does)
        // and on not implementing `Drop` in a way that would move it. The wrapper's auto
        // `Unpin` already follows `F`.
        let future = unsafe { self.map_unchecked_mut(|me| &mut me.future) };
        future.poll(cx).map(|result| MaskResultAsSend { result })
    }
}