1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
#![feature(min_const_generics)]

//! async-wormhole allows you to call `.await` async calls across non-async functions, like extern "C" or JIT
//! generated code.
//!
//! ## Motivation
//!
//! Sometimes, when running inside an async environment you need to call into JIT generated code (e.g. wasm)
//! and .await from there. Because the JIT code is not available at compile time, the Rust compiler can't
//! do its "create a state machine" magic. As a result, you can't have `.await` statements in non-async
//! functions.
//!
//! This library creates a special stack for executing the JIT code, so it's possible to suspend it at any
//! point of the execution. Once you pass it a closure inside [AsyncWormhole::new](struct.AsyncWormhole.html#method.new)
//! you will get back a future that you can `.await` on. The passed in closure is going to be executed on a
//! new stack.
//!
//! Sometimes you also need to preserve thread local storage as the code inside the closure expects it to stay
//! the same, but the actual execution can be moved between threads. There is a
//! [proof of concept API](struct.AsyncWormhole.html#method.preserve_tls)
//! to allow you to move your thread local storage with the closure across threads.
//!
//! ## Example
//!
//! ```rust
//! use async_wormhole::{AsyncWormhole, AsyncYielder};
//! use switcheroo::stack::*;
//!
//! // non-async function
//! extern "C" fn non_async(mut yielder: AsyncYielder<u32>) -> u32 {
//!     // Suspend the runtime until async value is ready.
//!     // Can contain .await calls.
//!     yielder.async_suspend(async { 42 })
//! }
//!
//! fn main() {
//!     let stack = EightMbStack::new().unwrap();
//!     let task = AsyncWormhole::new(stack, |yielder| {
//!         let result = non_async(yielder);
//!         assert_eq!(result, 42);
//!         64
//!     })
//!     .unwrap();
//!
//!     let outside = futures::executor::block_on(task);
//!     assert_eq!(outside.unwrap(), 64);
//! }
//! ```

pub mod pool;

use switcheroo::stack;
use switcheroo::Generator;
use switcheroo::Yielder;

use std::cell::Cell;
use std::convert::TryInto;
use std::future::Future;
use std::io::Error;
use std::pin::Pin;
use std::task::{Context, Poll, Waker};
use std::thread::LocalKey;

/// Holds one thread local variable that is preserved across context switches.
///
/// This gives code that uses thread local variables inside the closure the impression that it is
/// still running on the thread it started on, even if it has been moved to a different one.
struct ThreadLocal<TLS: 'static> {
    /// Key of the thread local slot whose content is carried along with the closure.
    reference: &'static LocalKey<Cell<*const TLS>>,
    /// Pointer most recently read from the slot; it is written back into the slot of whichever
    /// thread resumes the closure next.
    value: *const TLS,
}

// `Copy` and `Clone` are implemented by hand instead of derived: a derive would demand
// `TLS: Copy` / `TLS: Clone`, although only a reference and a raw pointer are ever duplicated.
impl<TLS> Copy for ThreadLocal<TLS> {}
impl<TLS> Clone for ThreadLocal<TLS> {
    fn clone(&self) -> Self {
        // Canonical `Clone` for a `Copy` type: a plain bitwise copy of both fields.
        *self
    }
}

/// AsyncWormhole captures a stack and a closure. It also implements Future and can be awaited on.
///
/// `TLS` is the pointee type of the thread local variables that are carried along with the
/// closure and `TLS_COUNT` is the number of such variables.
pub struct AsyncWormhole<'a, Stack: stack::Stack, Output, TLS: 'static, const TLS_COUNT: usize> {
    // Generator wrapping the closure and the stack it runs on. It is resumed with the waker of
    // the current poll and yields `None` when the closure suspends on a pending future, or
    // `Some(output)` once the closure has returned.
    generator: Cell<Generator<'a, Waker, Option<Output>, Stack>>,
    // Thread local values written back into the thread's slots before every resume and read
    // again after every suspension.
    preserved_thread_locals: [ThreadLocal<TLS>; TLS_COUNT],
}

// NOTE(review): this impl is unconditional. It requires `Send` of neither `Stack`, `Output`,
// `TLS` nor the captured closure (whose type does not appear in `AsyncWormhole` at all), and the
// struct stores raw `*const TLS` pointers. Nothing in the type system checks that moving the
// whole wormhole to another thread is sound; the TODO on `new_with_tls` explains why the closure
// is not bounded by `Send`.
unsafe impl<Stack: stack::Stack, Output, TLS, const TLS_COUNT: usize> Send
    for AsyncWormhole<'_, Stack, Output, TLS, TLS_COUNT>
{
}

impl<'a, Stack: stack::Stack, Output> AsyncWormhole<'a, Stack, Output, (), 0> {
    /// Creates an AsyncWormhole that executes the closure `f` on the passed `stack` and does not
    /// preserve any thread local variables.
    ///
    /// The closure is not run by this call. It only starts executing once the returned
    /// AsyncWormhole is handed to an async executor (i.e. it is `.await`ed).
    pub fn new<F>(stack: Stack, f: F) -> Result<Self, Error>
    where
        F: FnOnce(AsyncYielder<Output>) -> Output + 'a,
    {
        // A wormhole without thread locals is the TLS variant with zero captured keys.
        let no_thread_locals = [];
        Self::new_with_tls(no_thread_locals, stack, f)
    }
}

impl<'a, Stack: stack::Stack, Output, TLS, const TLS_COUNT: usize>
    AsyncWormhole<'a, Stack, Output, TLS, TLS_COUNT>
{
    /// Like `new`, but additionally captures the thread local variables behind `tls_refs`.
    ///
    /// While the future is being driven, an async executor is free to move the closure `f` from
    /// one thread to another. The captured thread local variables travel with it, so from the
    /// point of view of the code inside `f` they never change unexpectedly.
    ///
    /// The current value of every key in `tls_refs` is read when this function is called and is
    /// used as the value installed on the first resume.
    ///
    /// ### Safety
    ///
    /// If a thread local variable is only set and used inside of the `f` closure then it's safe
    /// to use it. Outside of the closure its content will be unpredictable.
    pub fn new_with_tls<F>(
        tls_refs: [&'static LocalKey<Cell<*const TLS>>; TLS_COUNT],
        stack: Stack,
        f: F,
    ) -> Result<Self, Error>
    where
        // TODO: This needs to be Send, but because Wasmtime's structures are not Send for now it
        // is not enforced on an API level. According to
        // https://github.com/bytecodealliance/wasmtime/issues/793#issuecomment-692740254
        // it is safe to move everything connected to a Store to a different thread all at once,
        // but this is impossible to express with the type system.
        F: FnOnce(AsyncYielder<Output>) -> Output + 'a,
    {
        // Run the closure inside the generator and hand its return value out as the final
        // `Some(..)` item.
        let generator = Generator::new(stack, |yielder, waker| {
            let output = f(AsyncYielder::new(yielder, waker));
            yielder.suspend(Some(output));
        });

        // Snapshot the current value of every requested thread local.
        let mut snapshots: Vec<ThreadLocal<TLS>> = Vec::with_capacity(TLS_COUNT);
        for &key in tls_refs.iter() {
            let current = key.with(|cell| cell.get());
            snapshots.push(ThreadLocal {
                reference: key,
                value: current,
            });
        }
        // Exactly TLS_COUNT snapshots were pushed, so the conversion to an array cannot fail.
        let preserved_thread_locals = snapshots.as_slice().try_into().unwrap();

        let wormhole = Self {
            generator: Cell::new(generator),
            preserved_thread_locals,
        };
        Ok(wormhole)
    }

    /// Consumes the wormhole and returns the stack held by the internal generator.
    pub fn stack(self) -> Stack {
        let generator = self.generator.into_inner();
        generator.stack()
    }
}

impl<'a, Stack: stack::Stack + Unpin, Output, TLS: Unpin, const TLS_COUNT: usize> Future
    for AsyncWormhole<'a, Stack, Output, TLS, TLS_COUNT>
{
    /// A ready value returned from `poll` is always `Some(output)`; an inner `None` is mapped to
    /// `Poll::Pending` instead.
    type Output = Option<Output>;

    /// Resumes the closure on its own stack, handing it the waker of the current task.
    ///
    /// Returns `Poll::Ready(Some(output))` once the closure has returned and `Poll::Pending`
    /// whenever it suspended on a future that is not ready yet.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Install the preserved values in the current thread's slots before re-entering the
        // closure.
        for slot in self.preserved_thread_locals.iter() {
            slot.reference.with(|cell| cell.set(slot.value));
        }

        let waker = cx.waker().clone();
        let resumed = self.generator.get_mut().resume(waker);

        match resumed {
            Some(Some(output)) => Poll::Ready(Some(output)),
            // Either the closure suspended (`Some(None)`) or the generator produced nothing
            // (`None`, presumably because it already finished; confirm against switcheroo).
            // Polling a future again after completion is outside the `Future` contract, so
            // reporting `Pending` here is acceptable.
            _ => {
                // Remember what the closure left in the thread locals, in case the next resume
                // happens on a different thread.
                for slot in self.preserved_thread_locals.iter_mut() {
                    let current = slot.reference.with(|cell| cell.get());
                    slot.value = current;
                }
                Poll::Pending
            }
        }
    }
}

/// Handle given to the closure running inside an `AsyncWormhole`. It lets non-async code wait
/// for a future through `async_suspend`.
#[derive(Clone)]
pub struct AsyncYielder<'a, Output> {
    // Yielder of the generator the closure runs in; suspending through it hands control back to
    // whoever resumed the generator.
    yielder: &'a Yielder<Waker, Option<Output>>,
    // Waker received on the most recent resume; used to build the `Context` for polled futures.
    waker: Waker,
}

impl<'a, Output> AsyncYielder<'a, Output> {
    /// Bundles the generator's `yielder` with the `waker` received on the first resume.
    pub(crate) fn new(yielder: &'a Yielder<Waker, Option<Output>>, waker: Waker) -> Self {
        Self { yielder, waker }
    }

    /// Takes an `impl Future` and drives it to completion, returning its output.
    ///
    /// Whenever the future is not ready, the surrounding generator is suspended with `None`.
    /// The waker handed in on the next resume replaces the stored one before the future is
    /// polled again.
    pub fn async_suspend<Fut, R>(&mut self, future: Fut) -> R
    where
        Fut: Future<Output = R>,
    {
        pin_utils::pin_mut!(future);
        loop {
            let mut cx = Context::from_waker(&self.waker);
            if let Poll::Ready(value) = future.as_mut().poll(&mut cx) {
                return value;
            }
            // Not ready yet: give control back and pick up the waker of the next poll.
            self.waker = self.yielder.suspend(None);
        }
    }
}