use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicPtr, Ordering};
use std::sync::Arc;
use std::task::Context;
use futures_task::noop_waker;
/// Maps a type to the "same" type with its lifetime parameter replaced by `'a`.
///
/// `Escher` stores a captured value under its `'static` rebinding (the `T` of
/// `Escher<'fut, T>`) and hands it back out under a shorter rebinding, so this
/// trait is what allows a borrowed value's lifetime to be erased and restored.
///
/// # Safety
///
/// For every `'a`, `Self::Out` must be the same type as `Self` except for
/// lifetime parameters, with identical size, alignment, layout and validity
/// invariants. `Capturer` casts a `*mut Self` to a
/// `*mut <Self as RebindTo<'static>>::Out`, and `Escher::as_ref` /
/// `Escher::as_mut` cast that pointer to `*mut Rebind<'a, _>`. Any other
/// choice of `Out` makes those casts undefined behaviour.
pub unsafe trait RebindTo<'a> {
    /// `Self` with its lifetime replaced by `'a`.
    type Out: 'a;
}

// SAFETY: `&'a T` and `&'_ T` differ only in their lifetime, so they share
// layout and validity.
unsafe impl<'a, T: ?Sized + 'static> RebindTo<'a> for &'_ T {
    type Out = &'a T;
}

// SAFETY: `&'a mut T` and `&'_ mut T` differ only in their lifetime, so they
// share layout and validity.
unsafe impl<'a, T: ?Sized + 'static> RebindTo<'a> for &'_ mut T {
    type Out = &'a mut T;
}

/// Types that can be rebound to *every* lifetime; the bound `Escher` requires.
pub trait Rebindable: for<'a> RebindTo<'a> {}
impl<T: for<'a> RebindTo<'a>> Rebindable for T {}

/// Shorthand for `T` with its lifetime rebound to `'a`.
pub type Rebind<'a, T> = <T as RebindTo<'a>>::Out;
/// A value that points at data owned by a suspended future which it keeps alive.
///
/// Construct it with `Escher::new`; access the captured value through
/// `Escher::as_ref` / `Escher::as_mut`.
pub struct Escher<'fut, T> {
    /// Address of the captured `T`, located inside the pinned storage of
    /// `_fut`. Raw pointers have no drop glue, so listing this field first
    /// does not change what gets dropped or when.
    ptr: *mut T,
    /// The suspended future that owns the captured data; dropping it drops
    /// that data.
    _fut: Pin<Box<dyn Future<Output = ()> + 'fut>>,
}
impl<'fut, T: Rebindable> Escher<'fut, T> {
    /// Builds a self-referential value by running `builder`'s future up to its
    /// first suspension point.
    ///
    /// `builder` receives a [`Capturer`] and returns a future that creates the
    /// data it wants to keep and then `.await`s `Capturer::capture` with the
    /// value to expose. The future is boxed, pinned, and polled exactly once
    /// with a no-op waker. `Escher` never polls it again, so the captured value
    /// stays in place inside the box for as long as the `Escher` exists.
    ///
    /// # Panics
    ///
    /// Panics if, after that single poll:
    /// - the `Capturer` is no longer alive (e.g. `capture()` was not awaited);
    /// - the future has already completed, so its locals are gone;
    /// - the captured value does not lie entirely inside the boxed future
    ///   (including the case where nothing was captured at all).
    pub fn new<B, F>(builder: B) -> Self
    where
        B: FnOnce(Capturer<T>) -> F,
        F: Future<Output = ()> + 'fut,
    {
        let ptr = Arc::new(AtomicPtr::new(std::ptr::null_mut()));
        let r = Capturer { ptr: ptr.clone() };
        let mut fut = Box::pin(builder(r));
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);
        let polled = fut.as_mut().poll(&mut cx);
        // The `Capturer` handed to `builder` owns the second strong reference;
        // it is only still alive if the future is suspended while holding it.
        assert!(
            Arc::strong_count(&ptr) == 2,
            "capture no longer live. Did you forget to .await the result of capture()?"
        );
        // A completed future has already dropped every local it held, so no
        // pointer into it can be valid.
        assert!(
            polled.is_pending(),
            "future completed before capture() suspended it; the captured value is no longer alive"
        );
        let ptr = ptr.load(Ordering::Acquire);
        let low = &*fut as *const F as usize;
        let high = low + std::mem::size_of_val(&*fut);
        let start = ptr as usize;
        // The whole `T` must lie inside the future's storage. Checking only the
        // start address would accept a value at (or straddling) the end of it.
        // A null pointer (nothing captured) fails `low <= start`.
        let end = start.saturating_add(std::mem::size_of::<T>());
        assert!(
            low <= start && end <= high,
            "captured value outside of async stack. Did you run capture() in a non async function?"
        );
        Escher { _fut: fut, ptr }
    }

    /// Borrows the captured value, with its lifetime shortened to the borrow
    /// of `self`.
    pub fn as_ref<'a>(&'a self) -> &Rebind<'a, T> {
        // SAFETY: `new` checked that `ptr` addresses a `T` lying inside `_fut`.
        // `_fut` is pinned (the value never moves), is not polled again by
        // `Escher` (the future does not touch the value), and is dropped only
        // together with `self`. Implementors of the `unsafe` trait `RebindTo`
        // are relied on to make `Rebind<'a, T>` identical to `T` up to
        // lifetimes, so reinterpreting the pointer is valid.
        unsafe { &*(self.ptr as *mut _) }
    }

    /// Mutably borrows the captured value, with its lifetime shortened to the
    /// borrow of `self`.
    ///
    /// NOTE(review): this appears unsound when `T` is itself a reference. For
    /// `T = &'static mut U` the return type is `&'a mut &'a mut U`, so a caller
    /// can write `*escher.as_mut() = &mut short_lived;`. Once `short_lived` is
    /// dropped, a later `as_ref`/`as_mut` yields a dangling reference. The same
    /// holds for `&'static U`. A sound design needs an API change, e.g. a
    /// closure accessor taking `for<'a> FnOnce(&'a mut Rebind<'a, T>) -> R`.
    pub fn as_mut<'a>(&'a mut self) -> &mut Rebind<'a, T> {
        // SAFETY: as in `as_ref`; additionally `&mut self` guarantees no other
        // reference obtained through this `Escher` is live.
        unsafe { &mut *(self.ptr as *mut _) }
    }
}
/// Handle passed to the builder given to `Escher::new`; consumed by
/// `Capturer::capture` to publish where the captured value lives.
pub struct Capturer<T> {
    /// Shared slot that `capture` writes the value's address into. The
    /// `Escher::new` side keeps the other `Arc` and inspects its strong count
    /// to tell whether this handle is still alive.
    ptr: std::sync::Arc<std::sync::atomic::AtomicPtr<T>>,
}
impl<StaticT> Capturer<StaticT> {
    /// Publishes the address of `*val` (reinterpreted as its `'static`
    /// rebinding) and then suspends forever.
    ///
    /// `self` is an async-fn parameter and therefore stays alive across the
    /// never-ending `.await`. Keeping the shared `Arc` alive is what lets
    /// `Escher::new` see a strong count of 2.
    async fn capture_ref<T>(self, val: &mut T)
    where
        T: RebindTo<'static, Out = StaticT>,
    {
        let raw: *mut T = val;
        self.ptr.store(raw.cast::<StaticT>(), Ordering::Release);
        std::future::pending::<()>().await;
    }

    /// Moves `val` into this future's own state and exposes it to the `Escher`.
    ///
    /// The returned future never completes, so `val` remains alive inside it
    /// until the future itself is dropped. `Escher::new` panics unless this
    /// future is still pending after its single poll and `val` ends up stored
    /// inside the builder's future.
    pub async fn capture<T>(self, val: T)
    where
        T: RebindTo<'static, Out = StaticT>,
    {
        let mut slot = val;
        self.capture_ref(&mut slot).await;
    }
}