use std::{
future::Future,
pin::Pin,
sync::{Arc, Condvar, Mutex},
task::{Context, Poll, RawWaker, RawWakerVTable, Waker},
};
pub fn block_on<F>(future: F) -> F::Output
where
F: Future,
{
let parker = Arc::new(Parker::default());
let waker = unsafe { Waker::from_raw(raw_waker(Arc::into_raw(parker.clone()) as *const ())) };
let mut future = Pin::from(Box::new(future));
let mut context = Context::from_waker(&waker);
loop {
parker.clear();
if let Poll::Ready(value) = future.as_mut().poll(&mut context) {
return value;
}
parker.park();
}
}
/// Blocking notification flag that `block_on` uses to sleep until a waker fires.
///
/// `notified` records whether a wake has happened since the last `clear`;
/// `ready` is signalled whenever `notified` is set to `true`.
#[derive(Default)]
struct Parker {
    notified: Mutex<bool>,
    ready: Condvar,
}
impl Parker {
    /// Locks the flag, recovering the guard if a previous holder panicked.
    ///
    /// The mutex only protects a `bool`, which cannot be left half-updated, so
    /// a poisoned lock is still safe to use. Recovering instead of panicking
    /// also keeps `wake` panic-free when it is called from waker vtable code.
    fn flag(&self) -> std::sync::MutexGuard<'_, bool> {
        self.notified
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Forgets any notification received so far.
    fn clear(&self) {
        *self.flag() = false;
    }

    /// Records a notification and wakes the parked thread, if any.
    fn wake(&self) {
        // The guard is a temporary released at the end of this statement, so the
        // notified thread can re-acquire the lock without contending with us.
        *self.flag() = true;
        self.ready.notify_one();
    }

    /// Blocks until a notification has been recorded (returns immediately if
    /// one already has). The flag is left set; callers reset it with `clear`.
    fn park(&self) {
        let _guard = self
            .ready
            .wait_while(self.flag(), |notified| !*notified)
            .unwrap_or_else(|poisoned| poisoned.into_inner());
    }
}
unsafe fn raw_waker(pointer: *const ()) -> RawWaker {
RawWaker::new(pointer, &VTABLE)
}
unsafe fn clone_waker(pointer: *const ()) -> RawWaker {
unsafe {
std::sync::Arc::<Parker>::increment_strong_count(pointer.cast::<Parker>());
raw_waker(pointer)
}
}
unsafe fn wake_waker(pointer: *const ()) {
let parker = unsafe { Arc::from_raw(pointer.cast::<Parker>()) };
parker.wake();
}
unsafe fn wake_by_ref_waker(pointer: *const ()) {
let parker = unsafe { Arc::from_raw(pointer.cast::<Parker>()) };
parker.wake();
let _ = Arc::into_raw(parker);
}
/// `drop` entry of the vtable: releases the strong reference owned by the waker.
///
/// # Safety
///
/// `pointer` must be an `Arc<Parker>` pointer whose strong reference is owned by
/// the waker being dropped; it must not be used by that waker afterwards.
unsafe fn drop_waker(pointer: *const ()) {
    // SAFETY: the dropped waker owned exactly one strong reference; giving it up
    // here frees the `Parker` if it was the last one.
    unsafe { Arc::decrement_strong_count(pointer.cast::<Parker>()) }
}
/// Dispatch table shared by every waker created for a `Parker`.
///
/// Each entry expects the waker's data pointer to be an `Arc<Parker>` produced
/// by `Arc::into_raw`.
static VTABLE: RawWakerVTable = RawWakerVTable::new(
    clone_waker,       // clone: take one more strong reference
    wake_waker,        // wake: signal, then release the waker's reference
    wake_by_ref_waker, // wake_by_ref: signal, reference count unchanged
    drop_waker,        // drop: release the waker's reference
);
#[cfg(test)]
mod tests {
    use super::block_on;
    use std::{
        future::Future,
        pin::Pin,
        sync::{
            Arc, Mutex,
            atomic::{AtomicBool, Ordering},
        },
        task::{Context, Poll, Waker},
        thread,
        time::Duration,
    };

    /// An immediately-ready future completes on the first poll.
    #[test]
    fn block_on_runs_ready_future() {
        let value = block_on(async { 42 });
        assert_eq!(value, 42);
    }

    /// A wake issued by value from another thread unparks the executor.
    #[test]
    fn block_on_observes_cross_thread_wakes() {
        struct ThreadWakeFuture {
            started: bool,
            ready: Arc<AtomicBool>,
        }
        impl Future for ThreadWakeFuture {
            type Output = u32;
            fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
                if self.ready.load(Ordering::SeqCst) {
                    return Poll::Ready(7);
                }
                if !self.started {
                    self.started = true;
                    let ready = Arc::clone(&self.ready);
                    let waker = cx.waker().clone();
                    thread::spawn(move || {
                        thread::sleep(Duration::from_millis(5));
                        ready.store(true, Ordering::SeqCst);
                        waker.wake();
                    });
                }
                Poll::Pending
            }
        }
        let value = block_on(ThreadWakeFuture {
            started: false,
            ready: Arc::new(AtomicBool::new(false)),
        });
        assert_eq!(value, 7);
    }

    /// Waking by reference during the poll itself must trigger a re-poll
    /// every time.
    #[test]
    fn block_on_handles_repeated_wake_by_ref_cycles() {
        struct SelfWakingFuture {
            remaining: usize,
        }
        impl Future for SelfWakingFuture {
            type Output = usize;
            fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
                if self.remaining == 0 {
                    return Poll::Ready(123);
                }
                self.remaining -= 1;
                cx.waker().wake_by_ref();
                Poll::Pending
            }
        }
        let value = block_on(SelfWakingFuture { remaining: 256 });
        assert_eq!(value, 123);
    }

    /// Wakers that escape the future must keep their parker alive after
    /// `block_on` returns; cloning, waking and dropping them must stay sound.
    #[test]
    fn wakers_outliving_block_on_remain_valid() {
        struct StashWakerFuture {
            slot: Arc<Mutex<Option<Waker>>>,
            polled: bool,
        }
        impl Future for StashWakerFuture {
            type Output = ();
            fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
                if self.polled {
                    return Poll::Ready(());
                }
                self.polled = true;
                *self.slot.lock().unwrap() = Some(cx.waker().clone());
                cx.waker().wake_by_ref();
                Poll::Pending
            }
        }
        let slot = Arc::new(Mutex::new(None));
        block_on(StashWakerFuture {
            slot: Arc::clone(&slot),
            polled: false,
        });
        let waker = slot
            .lock()
            .unwrap()
            .take()
            .expect("the future stashes a waker on its first poll");
        let extra = waker.clone();
        waker.wake_by_ref();
        waker.wake();
        extra.wake_by_ref();
        drop(extra);
    }
}