1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
#![allow(unsafe_code)] // future runtime require unsafe code to be implemented
#![allow(clippy::undocumented_unsafe_blocks)] // TODO: remove and comment blocks instead

use futures::channel::oneshot::channel;
use futures::channel::oneshot::Receiver;
use futures::FutureExt;
use std::future::Future;
use std::pin::Pin;
use std::task::Context;
use std::task::Poll;

/// Simple single-threaded futures runtime for Wasm
///
/// This is only available if building with the `with_async` feature flag:
/// ```toml
/// [dependencies]
/// ark-api = { version = "*", features = ["with_async"] }
/// ```
///
/// ## Usage
///
/// To run async code with this runtime, pass a future to the runtime's `spawn`
/// method and then call `poll` on every frame. Each call to `poll` polls every
/// still-pending future once, so spawned futures only make progress while
/// `poll` (or `block_on`) is being called.
///
/// If you are writing an applet you may want to create a new `Runtime` in your
/// `new` method and call `poll` on every frame in the `update` method.
///
/// ```rust
/// # use ark_module::Runtime;
/// # use futures::{Future, channel::oneshot::Receiver};
/// # pub trait Applet {
/// #     fn new() -> Self;
/// #     fn update(&mut self);
/// # }
/// # async fn some_async_function() -> String { "Hello".to_string() }
///
/// struct Module {
///     runtime: Runtime,
///     channel: Receiver<String>,
/// }
///
/// impl Applet for Module {
///     fn new() -> Self {
///         let mut runtime = Runtime::new();
///
///         // Call an async function to get a future
///         let future = some_async_function();
///         // Run the future on the runtime, getting back a channel for the final result
///         let channel = runtime.spawn(future);
///
///         Self { channel, runtime }
///     }
///
///     fn update(&mut self) {
///         // Poll the runtime to run the future
///         self.runtime.poll();
///
///         // Check the channel to see if the future has resolved yet
///         match self.channel.try_recv() {
///             Ok(Some(value)) => println!("The future resolved returning {}", value),
///             _ => println!("Still waiting for the future"),
///         };
///     }
/// }
/// ```
///
/// Here we are adding the future to the runtime with the `spawn` method in the
/// `new` method, but futures can be added to the runtime at any point in your
/// Ark module.
#[derive(Default)]
pub struct Runtime {
    /// Spawned tasks that have not resolved yet.
    tasks: Vec<Pin<Box<dyn Future<Output = ()> + 'static + Send>>>,
    // Spare buffer that `poll` fills with the still-pending tasks before swapping
    // it with `tasks`, so polling reuses allocations instead of allocating on
    // every call. This can be removed when drain_filter is on stable.
    // NOTE(review): `Vec::retain_mut` (stable since Rust 1.61) may already allow
    // removing this buffer — confirm against the crate's MSRV.
    allocation: Vec<Pin<Box<dyn Future<Output = ()> + 'static + Send>>>,
}

impl std::fmt::Debug for Runtime {
    /// Shows how many spawned tasks are still pending. The tasks themselves are
    /// opaque `dyn Future` trait objects and cannot be printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Runtime")
            .field("pending_tasks", &self.tasks.len())
            .finish_non_exhaustive()
    }
}

/// A waker whose clone, wake, wake-by-ref and drop operations all do nothing.
///
/// Polling a future requires a `Context`, which in turn requires a waker. Waking
/// through this one has no effect.
pub(crate) mod waker {
    use std::ptr;
    use std::task::{RawWaker, RawWakerVTable, Waker};

    /// Vtable shared by every no-op waker. `clone` hands back another no-op raw
    /// waker; wake, wake-by-ref and drop all ignore their data pointer.
    static NOOP_VTABLE: RawWakerVTable = RawWakerVTable::new(clone_noop, ignore, ignore, ignore);

    /// Builds a raw waker with a null data pointer backed by `NOOP_VTABLE`.
    fn noop_raw_waker() -> RawWaker {
        RawWaker::new(ptr::null(), &NOOP_VTABLE)
    }

    unsafe fn clone_noop(_data: *const ()) -> RawWaker {
        noop_raw_waker()
    }

    unsafe fn ignore(_data: *const ()) {}

    /// Returns a waker that does nothing when woken, cloned or dropped.
    pub fn waker() -> Waker {
        // SAFETY: none of the vtable functions dereference the (null) data
        // pointer or own any resource, so every operation the `RawWaker`
        // contract allows is trivially sound and thread-safe.
        unsafe { Waker::from_raw(noop_raw_waker()) }
    }
}
impl Runtime {
    /// Creates a new runtime with no spawned tasks.
    pub fn new() -> Self {
        Self::default()
    }

    /// Poll the runtime to update futures.
    ///
    /// Every spawned task is polled once; tasks that resolve are dropped and
    /// the rest are kept for the next call. The tasks are polled with a no-op
    /// waker, so a pending task only makes progress when `poll` is called again.
    /// This should be called on every frame until all futures managed by the
    /// runtime have resolved.
    pub fn poll(&mut self) {
        let waker = waker::waker();
        let mut context = Context::from_waker(&waker);

        // `allocation` is always left empty by the swap below (it receives the
        // drained `tasks` vector), so it can be refilled without reallocating.
        debug_assert!(self.allocation.is_empty());

        let still_pending = self.tasks.drain(..).filter_map(|mut task| {
            match task.as_mut().poll(&mut context) {
                Poll::Ready(()) => None,
                Poll::Pending => Some(task),
            }
        });
        self.allocation.extend(still_pending);
        // The still-pending tasks become the task list; the drained (empty but
        // still allocated) vector becomes the spare buffer for the next call.
        std::mem::swap(&mut self.tasks, &mut self.allocation);
    }

    /// Spawn a future, returning a channel that will receive the result of the
    /// future once it resolves.
    ///
    /// The future only makes progress while `poll` (or `block_on`) is called.
    /// If the returned receiver is dropped, the future still runs to completion
    /// and its output is discarded.
    ///
    /// Check for the result with the receiver's non-blocking `try_recv` method
    /// rather than blocking until the future resolves, which results in poor
    /// player experience and performance.
    pub fn spawn<T: Send + 'static>(
        &mut self,
        future: impl Future<Output = T> + 'static + Send,
    ) -> Receiver<T> {
        let (tx, rx) = channel();
        let task = async move {
            let r = future.await;
            // Sending only fails when the receiver has been dropped, in which
            // case nobody is interested in the result, so the error is ignored.
            let _ = tx.send(r);
        };
        self.tasks.push(task.boxed());
        rx
    }

    /// Drives a future to completion and returns the result.
    ///
    /// Note this function is synchronous: it blocks the caller, busy-polling
    /// until the future has resolved. Using this function in your game's update
    /// loop will likely result in poor player experience and performance.
    ///
    /// The runtime's spawned tasks are polled while waiting, so `future` may
    /// await the result of a task previously spawned on this runtime.
    pub fn block_on<T>(&mut self, mut future: impl Future<Output = T>) -> T {
        let waker = waker::waker();
        let mut context = Context::from_waker(&waker);

        // Pin the future to the stack so it doesn't have to be `'static` or
        // `Send`. Shadowing the binding makes the owned value unreachable, so
        // it can never be moved after being pinned.
        // SAFETY: the pinned value is a local owned by this stack frame; after
        // the shadowing below it is only reachable through the `Pin`, and it is
        // dropped in place when the function returns or unwinds.
        let mut future = unsafe { Pin::new_unchecked(&mut future) };

        // [TODO] Avoid busy loop here.
        loop {
            if let Poll::Ready(result) = future.as_mut().poll(&mut context) {
                return result;
            }
            // Let spawned tasks make progress; `future` may be waiting on them.
            self.poll();
        }
    }
}