1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicPtr, Ordering};
use std::sync::Arc;
use std::task::Context;

use futures_task::noop_waker;

/// Type-level function that converts a type holding references of some lifetime into the very
/// same type holding references of lifetime `'a` instead.
///
/// # Safety
///
/// Implementers must guarantee that [`RebindTo::Out`] differs from the implementing type **only**
/// in its lifetimes. The trait is `unsafe` precisely to rule out implementations like:
///
/// ```ignore
/// unsafe impl<'a> RebindTo<'a> for Foo<'_> {
///     type Out = Bar<'a>; // !!WRONG!!
/// }
///
/// unsafe impl<'a> RebindTo<'a> for Foo<'_> {
///     type Out = Foo<'a>; // CORRECT
/// }
/// ```
///
/// Rather than implementing this trait by hand, derive
/// [Rebindable](escher_derive::Rebindable).
pub unsafe trait RebindTo<'a> {
    /// The implementing type with every lifetime replaced by `'a`.
    type Out: 'a;
}

/// Shared references to `'static` data rebind by swapping the lifetime of the reference itself.
unsafe impl<'a, T> RebindTo<'a> for &'_ T
where
    T: ?Sized + 'static,
{
    type Out = &'a T;
}

/// Mutable references to `'static` data rebind by swapping the lifetime of the reference itself.
unsafe impl<'a, T> RebindTo<'a> for &'_ mut T
where
    T: ?Sized + 'static,
{
    type Out = &'a mut T;
}

/// Marker trait for types that implement [RebindTo] for every lifetime. It is implemented
/// automatically for all such types; custom types opt in by deriving
/// [Rebindable](escher_derive::Rebindable).
pub trait Rebindable: for<'a> RebindTo<'a> {}

impl<T> Rebindable for T where T: for<'a> RebindTo<'a> {}

/// `Rebind<'a, T>` names the type `T` with all of its lifetimes replaced by `'a`.
///
/// The alias only resolves when `T` implements [`RebindTo<'a>`](RebindTo), which every
/// [Rebindable] type does for any `'a`.
///
/// Examples:
///
/// * `Rebind<'a, &'static str> == &'a str`
/// * `Rebind<'static, &'a str> == &'static str`
/// * `Rebind<'c, T<'a, 'b>> == T<'c, 'c>`
pub type Rebind<'a, T> = <T as RebindTo<'a>>::Out;

/// A container for a self-referential struct. The self-referential struct is constructed with the
/// aid of the async/await machinery of rustc, see [Escher::new].
pub struct Escher<'fut, T> {
    /// The builder future, suspended inside `Capturer::capture`. It is polled exactly once (in
    /// [Escher::new]) and is kept only so that its state, which holds the captured value, stays
    /// alive and pinned.
    _fut: Pin<Box<dyn Future<Output = ()> + 'fut>>,
    /// Address of the captured value inside the state of `_fut`.
    ptr: *mut T,
}

impl<'fut, T: Rebindable> Escher<'fut, T> {
    /// Construct a self-referential struct using the provided closure. The user is expected to
    /// construct the desired data and references to them in the async stack and capture the
    /// desired state when ready.
    ///
    /// ```rust
    /// use escher::Escher;
    ///
    /// let escher_heart = Escher::new(|r| async move {
    ///     let data: Vec<u8> = vec![240, 159, 146, 150];
    ///     let sparkle_heart = std::str::from_utf8(&data).unwrap();
    ///
    ///     r.capture(sparkle_heart).await;
    /// });
    ///
    /// assert_eq!("💖", *escher_heart.as_ref());
    /// ```
    ///
    /// # Panics
    ///
    /// Panics if, after a single poll, the builder future is not suspended in
    /// `r.capture(..).await`. Specifically, it panics when:
    ///
    /// * the [Capturer] was dropped (e.g. `capture()` was called but not awaited),
    /// * the builder future ran to completion, or
    /// * the captured value does not lie entirely inside the future's state.
    pub fn new<B, F>(builder: B) -> Self
    where
        B: FnOnce(Capturer<T>) -> F,
        F: Future<Output = ()> + 'fut,
    {
        let ptr = Arc::new(AtomicPtr::new(std::ptr::null_mut()));
        let r = Capturer { ptr: ptr.clone() };
        let mut fut = Box::pin(builder(r));

        // Drive the builder to its first suspension point. A well-behaved builder stops inside
        // `Capturer::capture`, which never resolves.
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);
        let poll = fut.as_mut().poll(&mut cx);

        // Adversarial code can attempt to capture a value without awaiting on the result
        assert!(
            Arc::strong_count(&ptr) == 2,
            "capture no longer live. Did you forget to .await the result of capture()?"
        );

        // A future that has completed has already dropped every local of its async stack, so
        // any address it published is dangling. The strong count alone cannot detect this
        // (a leaked Capturer keeps it at 2).
        assert!(
            poll.is_pending(),
            "builder future completed instead of suspending in capture(). \
             Did you return before awaiting capture()?"
        );

        let ptr = ptr.load(Ordering::Acquire);

        // Adversarial code can attempt to capture a value that does not live on the async stack.
        // The whole `T`, not just its first byte, must lie within the future's state.
        let low = &*fut as *const _ as usize;
        let high = low + std::mem::size_of_val(&*fut);
        let start = ptr as usize;
        // Saturate so that a bogus address near the top of the address space fails the check
        // below instead of overflowing.
        let end = start.saturating_add(std::mem::size_of::<T>());
        assert!(
            low <= start && end <= high,
            "captured value outside of async stack. Did you run capture() in a non async function?"
        );

        // SAFETY: At this point we know that:
        // 1. The future is suspended (Pending), so its async stack has not been torn down
        // 2. We have a pointer to a T that lies entirely within the state of the future
        // 3. The state of the future will never move again because it's behind a Pin<Box<_>>
        // 4. The pointer `ptr` points to a valid instance of T because:
        //    a. T will be kept alive as long as the future is kept alive, and we own it
        //    b. The only way to set the pointer is through Capturer::capture, which expects a T
        //    c. The strong count of the Arc<AtomicPtr> is 2, so the async stack is suspended in
        //       Capturer::capture_ref because:
        //       α. Capturer is not Clone, so the extra reference must come from the single
        //          Capturer handed to the builder
        //       β. Capturer::capture consumes Capturer, so when the function returns the Arc
        //          will be dropped
        Escher { _fut: fut, ptr }
    }

    /// Get a shared reference to the inner `T` with its lifetimes bound to `&self`.
    pub fn as_ref<'a>(&'a self) -> &Rebind<'a, T> {
        // SAFETY:
        // Validity of reference
        //    self.ptr points to a valid instance of T inside of self._fut (see the safety
        //    argument in the constructor).
        // Liveness of reference
        //    The resulting reference has all its lifetimes bound to 'a, the borrow of self.
        //    Data on the async stack lives inside self._fut, which self owns and never polls
        //    again. Any data borrowed from outside the future outlives 'fut, and 'fut outlives
        //    'a because `&'a Escher<'fut, T>` is well-formed.
        unsafe { &*(self.ptr as *mut _) }
    }

    /// Get a mutable reference to the inner `T` with its lifetimes bound to `&mut self`.
    pub fn as_mut<'a>(&'a mut self) -> &mut Rebind<'a, T> {
        // SAFETY: see the safety argument of Self::as_ref. The `&mut self` borrow ensures no
        // reference obtained from as_ref/as_mut can coexist with this one.
        unsafe { &mut *(self.ptr as *mut _) }
    }
}

/// Handle given to the closure passed to [Escher::new]. Awaiting [Capturer::capture] inside the
/// builder's async block captures a value that lives on the async stack.
pub struct Capturer<T> {
    // Shared with `Escher::new`, which reads the published address and checks how many owners
    // the `Arc` still has.
    ptr: Arc<AtomicPtr<T>>,
}

impl<StaticT> Capturer<StaticT> {
    /// Publishes the address of `*val` through the shared pointer, then suspends forever. Because
    /// the future never resolves, both `val` and this `Capturer` stay alive for as long as the
    /// enclosing future does.
    async fn capture_ref<T>(self, val: &mut T)
    where
        // once rustc supports equality constraints this can become: `StaticT = Rebind<'static, T>`
        T: RebindTo<'static, Out = StaticT>,
    {
        let address: *mut T = val;
        self.ptr.store(address.cast::<StaticT>(), Ordering::Release);
        std::future::pending::<()>().await
    }

    /// Moves `val` into the async stack and captures it into a future that never resolves.
    ///
    /// The returned future **must** be `.await`ed for [Escher] to capture the value. Calling
    /// this method without awaiting the result captures nothing.
    pub async fn capture<T>(self, mut val: T)
    where
        // once rustc supports equality constraints this can become: `StaticT = Rebind<'static, T>`
        T: RebindTo<'static, Out = StaticT>,
    {
        let slot = &mut val;
        self.capture_ref(slot).await
    }
}