1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
//! async-wormhole allows you to `.await` async calls across non-async functions, like extern "C" or
//! JIT-generated code.
//!
//! ## Motivation
//!
//! Sometimes, when running inside an async environment, you need to call into JIT-generated code
//! (e.g. wasm) and `.await` from there. Because the JIT code is not available at compile time, the
//! Rust compiler can't do its "create a state machine" magic, so you can't have `.await` statements
//! in those non-async functions.
//!
//! This library creates a special stack for executing the JIT code, so it's possible to suspend it at any
//! point of the execution. Once you pass a closure to [AsyncWormhole::new](struct.AsyncWormhole.html#method.new),
//! you get back a future that you can `.await`. The passed-in closure is executed on the
//! new stack.
//!
//! Sometimes you also need to preserve thread-local storage, because the code inside the closure expects
//! it to stay the same even though the actual execution can be moved between threads. There is a
//! [proof of concept API](struct.AsyncWormhole.html#method.preserve_tls)
//! that lets you move your thread-local storage with the closure across threads.
//!
//! ## Example
//!
//! ```rust
//! use async_wormhole::{AsyncWormhole, AsyncYielder};
//! use switcheroo::stack::*;
//!
//! // non-async function
//! extern "C" fn non_async(mut yielder: AsyncYielder<u32>) -> u32 {
//!     // Suspend this stack until the async value is ready.
//!     // The async block can contain .await calls.
//!     yielder.async_suspend(async { 42 })
//! }
//!
//! fn main() {
//!     let stack = EightMbStack::new().unwrap();
//!     let task = AsyncWormhole::<_, _, ()>::new(stack, |yielder| {
//!         let result = non_async(yielder);
//!         assert_eq!(result, 42);
//!         64
//!     })
//!     .unwrap();
//!
//!     let outside = futures::executor::block_on(task);
//!     assert_eq!(outside.unwrap(), 64);
//! }
//! ```

pub mod pool;

use switcheroo::stack;
use switcheroo::Generator;
use switcheroo::Yielder;

use std::ptr;
use std::future::Future;
use std::io::Error;
use std::pin::Pin;
use std::task::{ Context, Poll, Waker };
use std::thread::LocalKey;
use std::cell::Cell;

/// This structure holds one thread local variable that is preserved across context switches.
/// This gives code that uses thread local variables inside the closure the impression that it is
/// running on the same thread it started on, even if it has been moved to a different thread.
/// TODO: This code is currently highly specific to WASMTIME's signal handler TLS and could be
/// generalized. The only issue is that we can't have `Default` traits on pointers and we need to
/// get rid of *const TLS in Wasmtime.
struct ThreadLocal<TLS: 'static> {
    /// The thread-local key whose contents are saved and restored. Despite the name, this is a
    /// `&'static LocalKey`, not a raw pointer.
    ptr: &'static LocalKey<Cell<*const TLS>>,
    /// The value last saved from `ptr`; `AsyncWormhole::preserve_tls` initializes it to null.
    value: *const TLS,
}

/// A future that runs a closure on its own stack (`Stack`), letting non-async code inside the
/// closure wait on futures through an [`AsyncYielder`]. Resolves to `Some(output)` once the
/// closure returns.
pub struct AsyncWormhole<'a, Stack: stack::Stack, Output, TLS: 'static> {
    /// The generator executing the closure. It is resumed with the current task's `Waker` and
    /// yields `Some(None)` while the closure waits on a pending future, or `Some(Some(output))`
    /// once the closure has finished.
    generator: Cell<Generator<'a, Waker, Option<Output>, Stack>>,
    /// Optional thread-local variable that `poll` restores before resuming the closure;
    /// registered via `preserve_tls`.
    thread_local: Option<ThreadLocal<TLS>>,
}

// NOTE(review): `Send` is asserted unconditionally, for every `Stack`, `Output` and `TLS`, even
// though `AsyncWormhole::new` does not require the closure to be `Send` (see the TODO on `new`).
// Soundness therefore relies on callers only moving a wormhole across threads when everything
// the closure captures may move with it. The compiler does not check this.
unsafe impl<Stack: stack::Stack, Output, TLS> Send for AsyncWormhole<'_, Stack, Output, TLS> {}

impl<'a, Stack: stack::Stack, Output, TLS> AsyncWormhole<'a, Stack, Output, TLS> {
    /// Builds a future that executes `f` on `stack`.
    ///
    /// `f` receives an [`AsyncYielder`] it can use to wait on futures from inside non-async
    /// code. Whatever `f` returns becomes the output of the returned future.
    ///
    /// # Errors
    ///
    /// The current implementation always returns `Ok`.
    pub fn new<F>(stack: Stack, f: F) -> Result<Self, Error>
    where
        // TODO: `F` should be `Send`, but because Wasmtime's structures are not `Send` for now,
        // the bound is not enforced at the API level. According to
        // https://github.com/bytecodealliance/wasmtime/issues/793#issuecomment-692740254
        // it is safe to move everything connected to a Store to a different thread all at once,
        // but this is impossible to express with the type system.
        F: FnOnce(AsyncYielder<Output>) -> Output + 'a,
    {
        let generator = Generator::new(stack, move |yielder, waker| {
            // Run the user closure on the new stack, then hand its result back to `poll`
            // as the final `Some(..)` value.
            let output = f(AsyncYielder::new(yielder, waker));
            yielder.suspend(Some(output));
        });

        Ok(Self {
            generator: Cell::new(generator),
            thread_local: None,
        })
    }

    /// Registers a thread-local variable whose value should travel with this future.
    ///
    /// `poll` writes the saved value into `tls` before resuming the closure. The saved value
    /// starts out as null, so the first `poll` overwrites whatever `tls` holds on that thread.
    pub fn preserve_tls(&mut self, tls: &'static LocalKey<Cell<*const TLS>>) {
        let slot = ThreadLocal {
            ptr: tls,
            value: ptr::null(),
        };
        self.thread_local = Some(slot);
    }

    /// Consumes the wormhole and hands back the stack owned by its internal generator.
    pub fn stack(self) -> Stack {
        let generator = self.generator.into_inner();
        generator.stack()
    }
}

impl<'a, Stack: stack::Stack + Unpin, Output, TLS: Unpin> Future
    for AsyncWormhole<'a, Stack, Output, TLS>
{
    type Output = Option<Output>;

    /// Resumes the closure on its own stack until it either finishes or waits on a pending
    /// future.
    ///
    /// If a thread-local variable was registered with `preserve_tls`, its saved value is written
    /// into the current thread before resuming. Whenever the closure suspends, the current value
    /// is saved again, so the value follows the closure from poll to poll (and across threads).
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Restore the saved TLS value. On the first `poll` the saved value is null, so the
        // existing TLS value on this thread is overwritten with null.
        if let Some(tls) = &self.thread_local {
            tls.ptr.with(|v| v.set(tls.value));
        }

        match self.generator.get_mut().resume(cx.waker().clone()) {
            // `Some(None)`: the closure suspended on a pending future.
            // `None`: the generator has already finished. Polling a completed future violates
            // the `Future` contract (it is not undefined behaviour); this implementation answers
            // such a call with `Poll::Pending`.
            None | Some(None) => {
                // Save the current TLS value so the next `poll`, possibly on another thread, can
                // restore it. The slot must stay inside `self`: `take()`-ing it would drop the
                // saved value and disable TLS preservation for every later poll.
                if let Some(tls) = self.thread_local.as_mut() {
                    let key = tls.ptr;
                    tls.value = key.with(|v| v.get());
                }
                Poll::Pending
            }
            // `out` is always `Some(..)` here because `Some(None)` was matched above.
            Some(out) => Poll::Ready(out),
        }
    }
}

/// Handle given to the closure passed to `AsyncWormhole::new`. Non-async code running on the
/// wormhole stack uses it to wait on futures via [`AsyncYielder::async_suspend`].
pub struct AsyncYielder<'a, Output> {
    /// Yielder of the generator whose stack this code is running on.
    yielder: &'a Yielder<Waker, Option<Output>>,
    /// Waker received on the most recent resume; used to build the `Context` for inner futures.
    waker: Waker,
}

// Implemented by hand rather than derived: `#[derive(Clone)]` would add an unnecessary
// `Output: Clone` bound, even though only a shared reference and a `Waker` are cloned.
impl<'a, Output> Clone for AsyncYielder<'a, Output> {
    fn clone(&self) -> Self {
        Self {
            yielder: self.yielder,
            waker: self.waker.clone(),
        }
    }
}

impl<'a, Output> AsyncYielder<'a, Output> {
    /// Wraps the generator's `yielder` together with the `waker` it was first resumed with.
    pub(crate) fn new(yielder: &'a Yielder<Waker, Option<Output>>, waker: Waker) -> Self {
        Self { yielder, waker }
    }

    /// Takes an `impl Future` and awaits it, returning the value from it once ready.
    ///
    /// While `future` is pending, the whole wormhole stack is suspended by yielding `None` to
    /// the generator; the thread is not blocked. When the outer `AsyncWormhole` is polled again,
    /// execution resumes here with that poll's waker, and `future` is polled again.
    pub fn async_suspend<Fut, R>(&mut self, future: Fut) -> R
    where
        Fut: Future<Output = R>,
    {
        pin_utils::pin_mut!(future);
        loop {
            // `Context::from_waker` only needs a shared borrow of the waker.
            let mut cx = Context::from_waker(&self.waker);
            match future.as_mut().poll(&mut cx) {
                Poll::Ready(result) => return result,
                Poll::Pending => self.waker = self.yielder.suspend(None),
            }
        }
    }
}