shuttle/future/
mod.rs

//! Shuttle's implementation of an async executor, roughly equivalent to [`futures::executor`].
//!
//! The [`spawn`] function spawns a new asynchronous task that the executor will run to completion.
//! The [`block_on`] function blocks the current thread until a future completes.
//!
//! [`futures::executor`]: https://docs.rs/futures/0.3.13/futures/executor/index.html
7
8use crate::runtime::execution::ExecutionState;
9use crate::runtime::task::TaskId;
10use crate::runtime::thread;
11use std::error::Error;
12use std::fmt::{Display, Formatter};
13use std::future::Future;
14use std::pin::Pin;
15use std::result::Result;
16use std::sync::Arc;
17use std::task::{Context, Poll, Waker};
18
19pub mod batch_semaphore;
20
21fn spawn_inner<F>(fut: F) -> JoinHandle<F::Output>
22where
23    F: Future + 'static,
24    F::Output: 'static,
25{
26    let stack_size = ExecutionState::with(|s| s.config.stack_size);
27    let inner = Arc::new(std::sync::Mutex::new(JoinHandleInner::default()));
28    let task_id = ExecutionState::spawn_future(Wrapper::new(fut, inner.clone()), stack_size, None);
29
30    thread::switch();
31
32    JoinHandle { task_id, inner }
33}
34
35/// Spawn a new async task that the executor will run to completion.
36pub fn spawn<F>(fut: F) -> JoinHandle<F::Output>
37where
38    F: Future + Send + 'static,
39    F::Output: Send + 'static,
40{
41    spawn_inner(fut)
42}
43
44/// Spawn a new async task that the executor will run to completion.
45/// This is just `spawn` without the `Send` bound, and it mirrors `spawn_local` from Tokio.
46pub fn spawn_local<F>(fut: F) -> JoinHandle<F::Output>
47where
48    F: Future + 'static,
49    F::Output: 'static,
50{
51    spawn_inner(fut)
52}
53
54/// An owned permission to abort a spawned task, without awaiting its completion.
55#[derive(Debug, Clone)]
56pub struct AbortHandle {
57    task_id: TaskId,
58}
59
60impl AbortHandle {
61    /// Abort the task associated with the handle.
62    pub fn abort(&self) {
63        ExecutionState::try_with(|state| {
64            if !state.is_finished() {
65                let task = state.get_mut(self.task_id);
66                task.abort();
67            }
68        });
69    }
70
71    /// Returns `true` if this task is finished, otherwise returns `false`.
72    ///
73    /// ## Panics
74    /// Panics if called outside of shuttle context, i.e. if there is no execution context.
75    pub fn is_finished(&self) -> bool {
76        ExecutionState::with(|state| {
77            let task = state.get(self.task_id);
78            task.finished()
79        })
80    }
81}
82
83unsafe impl Send for AbortHandle {}
84unsafe impl Sync for AbortHandle {}
85
86/// An owned permission to join on an async task (await its termination).
87#[derive(Debug)]
88pub struct JoinHandle<T> {
89    task_id: TaskId,
90    inner: std::sync::Arc<std::sync::Mutex<JoinHandleInner<T>>>,
91}
92
93#[derive(Debug)]
94struct JoinHandleInner<T> {
95    result: Option<Result<T, JoinError>>,
96    waker: Option<Waker>,
97}
98
99impl<T> Default for JoinHandleInner<T> {
100    fn default() -> Self {
101        JoinHandleInner {
102            result: None,
103            waker: None,
104        }
105    }
106}
107
108impl<T> JoinHandle<T> {
109    /// Abort the task associated with the handle.
110    pub fn abort(&self) {
111        ExecutionState::try_with(|state| {
112            if !state.is_finished() {
113                let task = state.get_mut(self.task_id);
114                task.abort();
115            }
116        });
117    }
118
119    /// Returns `true` if this task is finished, otherwise returns `false`.
120    ///
121    /// ## Panics
122    /// Panics if called outside of shuttle context, i.e. if there is no execution context.
123    pub fn is_finished(&self) -> bool {
124        ExecutionState::with(|state| {
125            let task = state.get(self.task_id);
126            task.finished()
127        })
128    }
129
130    /// Returns a new `AbortHandle` that can be used to remotely abort this task.
131    pub fn abort_handle(&self) -> AbortHandle {
132        AbortHandle { task_id: self.task_id }
133    }
134}
135
// TODO: need to work out all the error cases here
/// Task failed to execute to completion.
#[derive(Debug)]
pub enum JoinError {
    /// Task was aborted
    Cancelled,
}

impl Display for JoinError {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            JoinError::Cancelled => "task was cancelled",
        };
        f.write_str(message)
    }
}

/// `JoinError` has no underlying cause, so the default `Error` methods are sufficient.
impl Error for JoinError {}
153
154impl<T> Drop for JoinHandle<T> {
155    fn drop(&mut self) {
156        self.abort();
157    }
158}
159
160impl<T> Future for JoinHandle<T> {
161    type Output = Result<T, JoinError>;
162
163    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
164        let mut lock = self.inner.lock().unwrap();
165        if let Some(result) = lock.result.take() {
166            Poll::Ready(result)
167        } else {
168            lock.waker = Some(cx.waker().clone());
169            Poll::Pending
170        }
171    }
172}
173
174// We wrap a task returning a value inside a wrapper task that returns (). The wrapper
175// contains a mutex-wrapped field that stores the value and the waker for the task
176// waiting on the join handle. When `poll` returns `Poll::Ready`, the `Wrapper` stores
177// the result in the `result` field and wakes the `waker`.
178struct Wrapper<F: Future> {
179    future: Pin<Box<F>>,
180    inner: std::sync::Arc<std::sync::Mutex<JoinHandleInner<F::Output>>>,
181}
182
183impl<F> Wrapper<F>
184where
185    F: Future + 'static,
186    F::Output: 'static,
187{
188    fn new(future: F, inner: std::sync::Arc<std::sync::Mutex<JoinHandleInner<F::Output>>>) -> Self {
189        Self {
190            future: Box::pin(future),
191            inner,
192        }
193    }
194}
195
196impl<F> Future for Wrapper<F>
197where
198    F: Future + 'static,
199    F::Output: 'static,
200{
201    type Output = ();
202
203    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
204        match self.future.as_mut().poll(cx) {
205            Poll::Ready(result) => {
206                // If we've finished execution already (this task was detached), don't clean up. We
207                // can't access the state any more to destroy thread locals, and don't want to run
208                // any more wakers (which will be no-ops anyway).
209                if ExecutionState::try_with(|state| state.is_finished()).unwrap_or(true) {
210                    return Poll::Ready(());
211                }
212
213                // Run thread-local destructors before publishing the result.
214                // See `pop_local` for details on why this loop looks this slightly funky way.
215                // TODO: thread locals and futures don't mix well right now. each task gets its own
216                //       thread local storage, but real async code might know that its executor is
217                //       single-threaded and so use TLS to share objects across all tasks.
218                while let Some(local) = ExecutionState::with(|state| state.current_mut().pop_local()) {
219                    drop(local);
220                }
221
222                let mut lock = self.inner.lock().unwrap();
223                lock.result = Some(Ok(result));
224
225                if let Some(waker) = lock.waker.take() {
226                    waker.wake();
227                }
228
229                Poll::Ready(())
230            }
231            Poll::Pending => Poll::Pending,
232        }
233    }
234}
235
236/// Run a future to completion on the current thread.
237pub fn block_on<F: Future>(future: F) -> F::Output {
238    let mut future = Box::pin(future);
239    let waker = ExecutionState::with(|state| state.current_mut().waker());
240    let cx = &mut Context::from_waker(&waker);
241
242    loop {
243        match future.as_mut().poll(cx) {
244            Poll::Ready(result) => break result,
245            Poll::Pending => {
246                ExecutionState::with(|state| state.current_mut().sleep_unless_woken());
247            }
248        }
249
250        thread::switch();
251    }
252}
253
254/// Yields execution back to the scheduler.
255///
256/// Borrowed from the Tokio implementation.
257pub async fn yield_now() {
258    /// Yield implementation
259    struct YieldNow {
260        yielded: bool,
261    }
262
263    impl Future for YieldNow {
264        type Output = ();
265
266        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
267            if self.yielded {
268                return Poll::Ready(());
269            }
270
271            self.yielded = true;
272            cx.waker().wake_by_ref();
273            ExecutionState::request_yield();
274            Poll::Pending
275        }
276    }
277
278    YieldNow { yielded: false }.await
279}