use core::future::Future;
use core::mem;
use core::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};
/// Source of the park handle that `block_on_t` uses for the calling thread.
pub trait ThreadLocal: Sized {
    /// Constructs the handle for the current thread (called once per
    /// `block_on_t` invocation).
    fn current() -> Self;
}
/// Blocking primitive used by `block_on_handle` between polls of a future.
pub trait ParkHandle: Sized {
    /// Called by `block_on_handle` each time the future returns
    /// `Poll::Pending`. An implementation may block until `unpark` is called
    /// or may return immediately, in which case the future is simply re-polled.
    fn park(&self);
    /// Called when a waker built from this handle (see `Wakeable`) is woken.
    fn unpark(&self);
}
/// A [`ParkHandle`] that can be converted to and from a single opaque pointer,
/// which lets it back a [`Waker`] through a hand-built [`RawWakerVTable`].
///
/// Implementors provide only [`into_opaque`](Wakeable::into_opaque) and
/// [`from_opaque`](Wakeable::from_opaque). The default methods build the
/// waker and its vtable on top of them. Waking a waker calls
/// [`ParkHandle::unpark`] on the stored handle.
pub trait Wakeable: ParkHandle + Clone {
    /// Transfers ownership of `self` into an opaque pointer.
    ///
    /// The pointer must later be passed to
    /// [`from_opaque`](Wakeable::from_opaque) to reclaim (and eventually drop)
    /// the handle. Otherwise the handle is leaked.
    fn into_opaque(self) -> *const ();

    /// Reconstructs a handle from a pointer produced by
    /// [`into_opaque`](Wakeable::into_opaque).
    ///
    /// # Safety
    ///
    /// `data` must have been returned by `Self::into_opaque` and must still be
    /// live. The returned value takes over that pointer's ownership: the
    /// pointer may be reclaimed-and-dropped at most once. Callers that only
    /// want to borrow the handle must prevent it from being dropped (the
    /// vtable functions below use `ManuallyDrop` for this).
    unsafe fn from_opaque(data: *const ()) -> Self;

    /// Builds a [`RawWaker`] that owns a fresh clone of `self`.
    ///
    /// # Safety
    ///
    /// The returned `RawWaker` must be managed according to the `RawWaker`
    /// contract, normally by handing it to [`Waker::from_raw`]. Discarding it
    /// without calling its vtable `drop` leaks the cloned handle.
    unsafe fn raw_waker(&self) -> RawWaker {
        let data = self.clone().into_opaque();
        RawWaker::new(
            data,
            // `RawWaker::new` requires a `&'static RawWakerVTable`; this
            // constant expression is promoted to a static.
            &RawWakerVTable::new(
                Self::clone_waker,
                Self::wake,
                Self::wake_by_ref,
                Self::drop_waker,
            ),
        )
    }

    /// Returns a [`Waker`] whose wake operations call [`ParkHandle::unpark`]
    /// on a clone of `self`.
    fn waker(&self) -> Waker {
        // SAFETY: the `RawWaker` comes from `raw_waker`. Its vtable functions
        // below treat the stored pointer according to the `RawWakerVTable`
        // contract (borrow for clone/wake_by_ref, consume for wake/drop).
        unsafe { Waker::from_raw(self.raw_waker()) }
    }

    /// Vtable `clone`: returns a new `RawWaker` owning an additional clone of
    /// the handle, leaving the handle behind `data` untouched.
    ///
    /// # Safety
    ///
    /// `data` must be a live pointer produced by `Self::into_opaque`.
    unsafe fn clone_waker(data: *const ()) -> RawWaker {
        // SAFETY: guaranteed by the caller. The handle is only borrowed, so it
        // is wrapped in `ManuallyDrop`. Unlike a trailing `mem::forget`, this
        // also keeps it from being dropped if `clone` panics.
        let waker = mem::ManuallyDrop::new(unsafe { Self::from_opaque(data) });
        // SAFETY: the returned `RawWaker` is handed back to the `Waker`
        // machinery, which manages it through the vtable.
        unsafe { Self::raw_waker(&waker) }
    }

    /// Vtable `wake`: consumes the handle behind `data` and unparks it.
    ///
    /// # Safety
    ///
    /// `data` must be a live pointer produced by `Self::into_opaque`; it is
    /// consumed by this call.
    unsafe fn wake(data: *const ()) {
        // SAFETY: guaranteed by the caller. Ownership is transferred here, so
        // the handle is dropped after unparking (also on panic).
        let waker = unsafe { Self::from_opaque(data) };
        waker.unpark();
    }

    /// Vtable `wake_by_ref`: unparks the handle behind `data` without
    /// consuming it.
    ///
    /// # Safety
    ///
    /// `data` must be a live pointer produced by `Self::into_opaque`.
    unsafe fn wake_by_ref(data: *const ()) {
        // SAFETY: guaranteed by the caller. `ManuallyDrop` keeps the borrowed
        // handle alive even if `unpark` panics.
        let waker = mem::ManuallyDrop::new(unsafe { Self::from_opaque(data) });
        waker.unpark();
    }

    /// Vtable `drop`: reclaims and drops the handle behind `data`.
    ///
    /// # Safety
    ///
    /// `data` must be a live pointer produced by `Self::into_opaque`; it is
    /// consumed by this call.
    unsafe fn drop_waker(data: *const ()) {
        // SAFETY: guaranteed by the caller; ownership is transferred here.
        drop(unsafe { Self::from_opaque(data) });
    }
}
/// `no_std` fallback handle: there is no per-thread state to capture, so the
/// "current" handle is always the null pointer.
impl ThreadLocal for *const () {
    fn current() -> Self {
        core::ptr::null::<()>()
    }
}
/// `no_std` fallback park strategy: never blocks.
///
/// `park` issues a single spin-loop hint and returns, so `block_on_handle`
/// immediately re-polls the future. `unpark` therefore has nothing to signal.
impl ParkHandle for *const () {
    fn unpark(&self) {
        // Intentionally empty: `park` never sleeps, so there is no one to wake.
    }

    fn park(&self) {
        core::hint::spin_loop();
    }
}
/// The `no_std` fallback handle already is an opaque pointer, so both
/// conversions are the identity and carry no ownership.
impl Wakeable for *const () {
    fn into_opaque(self) -> *const () {
        core::convert::identity(self)
    }

    unsafe fn from_opaque(data: *const ()) -> Self {
        core::convert::identity(data)
    }
}
/// Runs `fut` to completion using the park handle `T::current()` for the
/// calling thread, plus a waker derived from it.
pub fn block_on_t<T: ParkHandle + Wakeable + ThreadLocal, F: Future>(fut: F) -> F::Output {
    // The waker owns its own clone of the handle, so both can be borrowed by
    // the polling loop for its whole duration.
    let parker = T::current();
    let wake_handle = parker.waker();
    block_on_handle(fut, &parker, &wake_handle)
}
/// Polls `fut` with `waker` until it is ready, calling `handle.park()` after
/// every `Poll::Pending`, and returns the future's output.
pub fn block_on_handle<T: ParkHandle, F: Future>(
    mut fut: F,
    handle: &T,
    waker: &Waker,
) -> F::Output {
    // SAFETY: `fut` is owned by this stack frame and is shadowed by the pinned
    // reference right here, so it can never be moved again; it is dropped in
    // place when the function returns.
    let mut fut = unsafe { core::pin::Pin::new_unchecked(&mut fut) };
    let mut cx = Context::from_waker(waker);
    loop {
        if let Poll::Ready(output) = fut.as_mut().poll(&mut cx) {
            return output;
        }
        handle.park();
    }
}
/// Runs `fut` to completion on the current thread.
///
/// With the `std` feature the thread sleeps on a `LocalThread` signal between
/// polls. Without it, the `*const ()` handle is used, whose `park` only
/// spins, so the future is re-polled in a busy loop.
pub fn block_on<F: Future>(fut: F) -> F::Output {
    #[cfg(feature = "std")]
    type Handle = LocalThread;
    #[cfg(not(feature = "std"))]
    type Handle = *const ();

    block_on_t::<Handle, _>(fut)
}
#[cfg(feature = "std")]
pub use std_impl::LocalThread;
#[cfg(feature = "std")]
mod std_impl {
    use super::*;
    use std::sync::{Arc, Condvar, Mutex};

    /// An auto-resetting wake-up flag guarded by a mutex/condvar pair.
    #[derive(Default)]
    struct Signal {
        /// Set by `wake`, cleared by `wait`.
        signaled: Mutex<bool>,
        cond: Condvar,
    }

    impl Signal {
        /// Blocks until the flag is set, then clears it.
        ///
        /// A `wake` that happens before `wait` leaves the flag set, so that
        /// wake-up is not lost. `wait_while` re-checks the flag after every
        /// condvar wake-up, which filters out spurious ones.
        fn wait(&self) {
            let guard = self.signaled.lock().unwrap();
            let mut signaled = self
                .cond
                .wait_while(guard, |signaled| !*signaled)
                .unwrap();
            *signaled = false;
        }

        /// Sets the flag and wakes one waiting thread, if any.
        fn wake(&self) {
            let mut signaled = self.signaled.lock().unwrap();
            *signaled = true;
            // Notifying while the guard is held is fine: the waiter can only
            // re-check the flag after this guard is released.
            self.cond.notify_one();
        }
    }

    thread_local! {
        /// The signal shared by every `LocalThread` handle (and hence every
        /// waker) created on this thread.
        ///
        /// NOTE(review): because the flag is per thread, nested `block_on`
        /// calls on the same thread share it. A wake-up meant for the outer
        /// call may be consumed by the inner call's `park`.
        static ACCESS: Arc<Signal> = Arc::new(Signal::default());
    }

    /// Park handle for the calling thread, backed by a thread-local signal.
    ///
    /// `park` blocks until `unpark` has been called (possibly earlier, from
    /// any thread holding a clone).
    #[derive(Clone)]
    pub struct LocalThread(Arc<Signal>);

    impl ThreadLocal for LocalThread {
        fn current() -> Self {
            LocalThread(ACCESS.with(Arc::clone))
        }
    }

    impl ParkHandle for LocalThread {
        fn park(&self) {
            self.0.wait();
        }

        fn unpark(&self) {
            self.0.wake();
        }
    }

    impl Wakeable for LocalThread {
        fn into_opaque(self) -> *const () {
            // `Arc::into_raw` is the documented Arc -> pointer conversion; the
            // strong reference travels with the pointer. (A `transmute` of the
            // wrapper would rely on unspecified struct layout.)
            Arc::into_raw(self.0).cast::<()>()
        }

        unsafe fn from_opaque(data: *const ()) -> Self {
            // SAFETY: the caller guarantees `data` came from `into_opaque`,
            // i.e. from `Arc::<Signal>::into_raw`, and reclaims it only once.
            LocalThread(unsafe { Arc::from_raw(data.cast::<Signal>()) })
        }
    }
}