#![deny(missing_docs)]
use std::mem::MaybeUninit;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::task::Waker;
use std::{
cell::UnsafeCell,
pin::Pin,
ptr::NonNull,
task::{Context, Poll},
};
/// Poll state: nothing has happened yet — no value set, future not waiting.
const IDLE: usize = 0;
/// Poll state: the value is stored; the future was not waiting when it arrived.
const SET: usize = 1;
/// Poll state: the value is stored; the future was waiting when it arrived.
const SET_WAIT: usize = 2;
/// Poll state: the future returned `Pending` and its waker sits in the slot.
const WAIT: usize = 3;
/// Poll state: the future is inside `poll` swapping its registered waker.
const WAIT_REPLACE: usize = 4;
/// Poll state: the future has moved the value out.
const DRAIN: usize = 5;
/// Poll state: one side was dropped before the value reached the future.
const WASTE: usize = 6;

/// Bit position where the reference count starts in the packed state word.
const REF_SHIFT: u32 = 3;
/// Amount the packed state word changes by per live handle.
const REF_ONE: usize = 1 << REF_SHIFT;
/// Mask selecting the poll state (the bits below the reference count).
const POLL_MASK: usize = REF_ONE - 1;

/// Error produced by [`Future`] when no value can be delivered.
///
/// Returned when the [`Promise`] was dropped without setting a value, and when
/// the future is polled again after it already yielded its value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Waste;

impl std::fmt::Display for Waste {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("no value available: the promise was dropped or the value was already taken")
    }
}

impl std::error::Error for Waste {}

/// Heap allocation shared by one [`Promise`] and one [`Future`].
///
/// `#[repr(C)]` with `value` as the first field makes a pointer to the whole
/// struct also a pointer to the value slot; `Promise::into_raw` relies on it.
#[repr(C)]
struct Shared<T> {
    /// Value slot: written by the promise, moved out by whichever side ends up
    /// owning the value. `MaybeUninit` never drops its contents on its own.
    value: UnsafeCell<MaybeUninit<T>>,
    /// Packed word: low bits are the poll state, the rest count live handles.
    state: AtomicUsize,
    /// Waker registered by the future while it waits.
    waker: UnsafeCell<Option<Waker>>,
}

/// Sending half of a one-shot channel created by [`pair`].
///
/// Consuming it with [`Promise::set_value`] delivers a value to the matching
/// [`Future`]; dropping it without a value makes the future resolve to
/// `Err(Waste)`.
pub struct Promise<T> {
    shared: NonNull<Shared<T>>,
}

// SAFETY: a handle only moves a `T` across threads; every access to the shared
// cells is ordered through the atomic `state` word, and `Waker` is Send + Sync.
unsafe impl<T: Send> Send for Promise<T> {}
unsafe impl<T: Send> Send for Future<T> {}

/// Extracts the poll state from a packed state word.
fn get_poll_state(s: usize) -> usize {
    s & POLL_MASK
}

/// Returns `s` with its poll state replaced, keeping the reference count.
fn set_poll_state(s: usize, poll_state: usize) -> usize {
    (s & !POLL_MASK) | poll_state
}

/// Extracts the number of live handles from a packed state word.
fn get_ref_state(s: usize) -> usize {
    s >> REF_SHIFT
}

/// Gives up one handle's reference, freeing the allocation if it was the last.
///
/// Each handle calls this exactly once, *after* its final access to the shared
/// state, so the other side can never free memory this side is still using.
///
/// # Safety
///
/// `shared` must come from [`pair`] and the caller must own a reference that
/// it never uses again.
unsafe fn release<T>(shared: NonNull<Shared<T>>) {
    // AcqRel: our earlier accesses happen-before a free by the other side, and
    // if we are last we observe everything the other side did before releasing.
    let prev = unsafe { shared.as_ref() }
        .state
        .fetch_sub(REF_ONE, Ordering::AcqRel);
    if get_ref_state(prev) == 1 {
        // SAFETY: no handle references the allocation any more, and the pointer
        // originates from the `Box` leaked in `pair`.
        drop(unsafe { Box::from_raw(shared.as_ptr()) });
    }
}

impl<T> Promise<T> {
    /// Stores `value` and wakes the matching [`Future`] if it is waiting.
    ///
    /// Returns `None` when the value was handed over, or gives it back as
    /// `Some(value)` if the future has already been dropped.
    pub fn set_value(self, value: T) -> Option<T> {
        let shared = unsafe { self.shared.as_ref() };
        // SAFETY: only the promise writes the slot, and the future reads it only
        // after observing SET/SET_WAIT, which `notify_value` publishes afterwards.
        unsafe { (*shared.value.get()).as_mut_ptr().write(value) };
        self.notify_value()
    }

    /// Converts the promise into a raw pointer to its (uninitialized) value slot.
    ///
    /// Write a `T` through the pointer, then rebuild the promise with
    /// [`Promise::from_raw`] and call [`Promise::assume_value_set`]. The promise
    /// is not dropped here, so the pointer must eventually be turned back into
    /// a promise or the channel leaks and the future never resolves.
    pub fn into_raw(s: Promise<T>) -> NonNull<MaybeUninit<T>> {
        // Valid because `Shared` is `repr(C)` with the value slot first.
        let ptr = s.shared.cast::<MaybeUninit<T>>();
        std::mem::forget(s);
        ptr
    }

    /// Rebuilds a promise from a pointer returned by [`Promise::into_raw`].
    ///
    /// # Safety
    ///
    /// `ptr` must come from [`Promise::into_raw`] on a `Promise<T>` and must be
    /// passed to `from_raw` at most once.
    pub unsafe fn from_raw(ptr: NonNull<MaybeUninit<T>>) -> Promise<T> {
        Promise {
            shared: ptr.cast::<Shared<T>>(),
        }
    }

    /// Publishes the already-written value and consumes the promise.
    ///
    /// Returns the value back if the future has been dropped.
    fn notify_value(self) -> Option<T> {
        let shared = unsafe { self.shared.as_ref() };
        let mut state = shared.state.load(Ordering::Relaxed);
        let prev_poll_state = loop {
            let poll_state = get_poll_state(state);
            let next_state = match poll_state {
                IDLE => set_poll_state(state, SET),
                WAIT | WAIT_REPLACE => set_poll_state(state, SET_WAIT),
                WASTE => {
                    // The future is gone: hand the value back. `self` is dropped
                    // on return and its `Drop` releases this handle's reference.
                    // SAFETY: the slot was initialized by our caller and the
                    // future never reads it once the state is WASTE.
                    return Some(unsafe { (*shared.value.get()).as_ptr().read() });
                }
                _ => unreachable!("unexpected state {}", poll_state),
            };
            match shared.state.compare_exchange_weak(
                state,
                next_state,
                Ordering::AcqRel,
                Ordering::Acquire,
            ) {
                Ok(_) => break poll_state,
                Err(actual) => state = actual,
            }
        };
        // The value is published; from here on only the reference is ours.
        // Forget first so a panicking waker cannot run `Drop` in a SET state.
        let ptr = self.shared;
        std::mem::forget(self);
        if prev_poll_state == WAIT {
            // SAFETY: we moved the state out of WAIT, so the future can no longer
            // win WAIT -> WAIT_REPLACE and the waker slot belongs to us.
            let waker = unsafe { (*shared.waker.get()).take() };
            waker.expect("a waker is registered in WAIT").wake();
        }
        // SAFETY: this handle owns one reference and never touches `shared` again.
        unsafe { release(ptr) };
        None
    }

    /// Marks the value as written after it was stored through [`Promise::into_raw`].
    ///
    /// If the future has already been dropped, the value is dropped here.
    ///
    /// # Safety
    ///
    /// A valid `T` must have been written to the slot returned by
    /// [`Promise::into_raw`] for this promise.
    pub unsafe fn assume_value_set(self) {
        drop(self.notify_value());
    }
}

impl<T> Drop for Promise<T> {
    /// Runs whenever the promise did not hand a value over: it marks the
    /// channel WASTE (so the future resolves to `Err(Waste)`), wakes a waiting
    /// future, and then releases this handle's reference.
    fn drop(&mut self) {
        let shared = unsafe { self.shared.as_ref() };
        let mut state = shared.state.load(Ordering::Relaxed);
        loop {
            let poll_state = get_poll_state(state);
            match poll_state {
                // The future is already gone (or done): nothing to signal.
                DRAIN | WASTE => break,
                SET | SET_WAIT => unreachable!("unexpected state {}", poll_state),
                _ => {}
            }
            match shared.state.compare_exchange_weak(
                state,
                set_poll_state(state, WASTE),
                Ordering::AcqRel,
                Ordering::Acquire,
            ) {
                Ok(_) => {
                    // Leaving WAIT hands the waker slot to us. Leaving IDLE or
                    // WAIT_REPLACE leaves it with the future, so it is not touched.
                    if poll_state == WAIT {
                        // SAFETY: the future cannot lock the slot any more.
                        let waker = unsafe { (*shared.waker.get()).take() };
                        waker.expect("a waker is registered in WAIT").wake();
                    }
                    break;
                }
                Err(actual) => state = actual,
            }
        }
        // SAFETY: this handle owns one reference and never touches `shared` again.
        unsafe { release(self.shared) };
    }
}

/// Receiving half of a one-shot channel created by [`pair`].
///
/// Resolves to `Ok(value)` once the [`Promise`] delivers a value, or to
/// `Err(Waste)` if the promise is dropped without one.
pub struct Future<T> {
    shared: NonNull<Shared<T>>,
}

impl<T> Future<T> {
    /// Publishes the waker stored by `poll` by moving IDLE -> WAIT, or handles
    /// whatever transition the promise made first.
    fn poll_idle(&self, shared: &Shared<T>, mut state: usize) -> Poll<Result<T, Waste>> {
        loop {
            match shared.state.compare_exchange_weak(
                state,
                set_poll_state(state, WAIT),
                Ordering::AcqRel,
                Ordering::Acquire,
            ) {
                Ok(_) => return Poll::Pending,
                Err(actual) => {
                    state = actual;
                    match get_poll_state(state) {
                        // A weak CAS may fail spuriously: retry.
                        IDLE => {}
                        SET => return self.poll_set(shared, state),
                        WASTE => {
                            // SAFETY: the promise left IDLE without touching the
                            // waker slot, so the waker stored by `poll` is ours.
                            unsafe { *shared.waker.get() = None };
                            return Poll::Ready(Err(Waste));
                        }
                        s => unreachable!("unexpected state {}", s),
                    }
                }
            }
        }
    }

    /// Claims the stored value by moving SET/SET_WAIT -> DRAIN and returns it.
    fn poll_set(&self, shared: &Shared<T>, mut state: usize) -> Poll<Result<T, Waste>> {
        loop {
            match shared.state.compare_exchange_weak(
                state,
                set_poll_state(state, DRAIN),
                Ordering::AcqRel,
                Ordering::Acquire,
            ) {
                Ok(_) => {
                    // SAFETY: SET/SET_WAIT is published only after the promise
                    // wrote the value, and DRAIN ensures it is moved out once.
                    let value = unsafe { (*shared.value.get()).as_ptr().read() };
                    return Poll::Ready(Ok(value));
                }
                Err(actual) => {
                    state = actual;
                    match get_poll_state(state) {
                        // Only the reference count changed, or a spurious failure.
                        SET | SET_WAIT => {}
                        s => unreachable!("unexpected state {}", s),
                    }
                }
            }
        }
    }

    /// Re-polls a waiting future: locks the waker slot (WAIT -> WAIT_REPLACE),
    /// refreshes the waker if the task changed, then unlocks (-> WAIT).
    fn poll_wait(
        &self,
        shared: &Shared<T>,
        mut state: usize,
        ctx: &mut Context,
    ) -> Poll<Result<T, Waste>> {
        // Phase 1: lock the slot. Remember the locked state for phase 3.
        state = loop {
            let next_state = set_poll_state(state, WAIT_REPLACE);
            match shared.state.compare_exchange_weak(
                state,
                next_state,
                Ordering::AcqRel,
                Ordering::Acquire,
            ) {
                Ok(_) => break next_state,
                Err(actual) => {
                    state = actual;
                    match get_poll_state(state) {
                        WAIT => {}
                        // The promise left WAIT itself, so it owns the waker slot
                        // now (it takes and wakes it); the slot is not touched here.
                        SET_WAIT => return self.poll_set(shared, state),
                        WASTE => return Poll::Ready(Err(Waste)),
                        s => unreachable!("unexpected state {}", s),
                    }
                }
            }
        };

        // Phase 2: we hold WAIT_REPLACE, so the promise leaves the slot alone.
        // SAFETY: exclusive access to the waker slot is guaranteed by the lock.
        let slot = unsafe { &mut *shared.waker.get() };
        let up_to_date = slot
            .as_ref()
            .map_or(false, |current| current.will_wake(ctx.waker()));
        if !up_to_date {
            *slot = Some(ctx.waker().clone());
        }

        // Phase 3: unlock. If the promise acted meanwhile, it saw WAIT_REPLACE and
        // left the waker to us, so clean it up before returning.
        loop {
            match shared.state.compare_exchange_weak(
                state,
                set_poll_state(state, WAIT),
                Ordering::AcqRel,
                Ordering::Acquire,
            ) {
                Ok(_) => return Poll::Pending,
                Err(actual) => {
                    state = actual;
                    match get_poll_state(state) {
                        WAIT_REPLACE => {}
                        SET_WAIT => {
                            // SAFETY: the promise did not take the waker (see above).
                            unsafe { *shared.waker.get() = None };
                            return self.poll_set(shared, state);
                        }
                        WASTE => {
                            // SAFETY: the promise did not take the waker (see above).
                            unsafe { *shared.waker.get() = None };
                            return Poll::Ready(Err(Waste));
                        }
                        s => unreachable!("unexpected state {}", s),
                    }
                }
            }
        }
    }
}

impl<T> std::future::Future for Future<T> {
    type Output = Result<T, Waste>;

    /// Dispatches on the current poll state (see the state constants).
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let shared = unsafe { self.shared.as_ref() };
        let state = shared.state.load(Ordering::Relaxed);
        let poll_state = get_poll_state(state);
        match poll_state {
            IDLE => {
                // SAFETY: the promise never touches the waker slot when it leaves
                // IDLE, and `poll` takes `&mut self`, so nobody else writes it.
                unsafe { *shared.waker.get() = Some(cx.waker().clone()) };
                self.poll_idle(shared, state)
            }
            SET | SET_WAIT => self.poll_set(shared, state),
            WAIT => self.poll_wait(shared, state, cx),
            DRAIN | WASTE => Poll::Ready(Err(Waste)),
            _ => unreachable!("unexpected state {}", poll_state),
        }
    }
}

impl<T> Drop for Future<T> {
    /// Marks the channel WASTE (unless the value was already taken or the
    /// promise is gone), disposes of whatever this side now owns, and then
    /// releases this handle's reference.
    fn drop(&mut self) {
        let shared = unsafe { self.shared.as_ref() };
        let mut state = shared.state.load(Ordering::Acquire);
        let prev_poll_state = loop {
            let poll_state = get_poll_state(state);
            if let DRAIN | WASTE = poll_state {
                break poll_state;
            }
            match shared.state.compare_exchange_weak(
                state,
                set_poll_state(state, WASTE),
                Ordering::AcqRel,
                Ordering::Acquire,
            ) {
                Ok(_) => break poll_state,
                Err(actual) => state = actual,
            }
        };
        match prev_poll_state {
            SET | SET_WAIT => {
                // SAFETY: the value was published and nobody else will read it.
                drop(unsafe { (*shared.value.get()).as_ptr().read() });
            }
            WAIT => {
                // SAFETY: we left WAIT ourselves, so the promise will not touch
                // the waker slot.
                unsafe { *shared.waker.get() = None };
            }
            // WASTE set by the promise: the promise or an earlier poll already
            // handled the waker; anything left is dropped with the allocation.
            _ => {}
        }
        // SAFETY: this handle owns one reference and never touches `shared` again.
        unsafe { release(self.shared) };
    }
}

/// Creates a connected [`Future`] and [`Promise`].
///
/// Both handles share one heap allocation. It is freed once both have released
/// it: dropping either handle, or calling [`Promise::set_value`], releases one.
pub fn pair<T>() -> (Future<T>, Promise<T>) {
    let shared = NonNull::from(Box::leak(Box::new(Shared {
        value: UnsafeCell::new(MaybeUninit::uninit()),
        // IDLE, with one reference for each of the two handles.
        state: AtomicUsize::new(IDLE | 2 * REF_ONE),
        waker: UnsafeCell::new(None),
    })));
    (Future { shared }, Promise { shared })
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::task::Wake;

    /// Waker that records how many times it has been woken.
    struct CountingWaker(AtomicUsize);

    impl Wake for CountingWaker {
        fn wake(self: Arc<Self>) {
            self.0.fetch_add(1, Ordering::SeqCst);
        }
    }

    /// Builds a waker together with a handle for reading its wake count.
    fn counting_waker() -> (Arc<CountingWaker>, Waker) {
        let counter = Arc::new(CountingWaker(AtomicUsize::new(0)));
        let waker = Waker::from(Arc::clone(&counter));
        (counter, waker)
    }

    /// Polls `f` once by hand so a test can observe `Pending` states without
    /// an executor.
    fn poll_once<T>(f: &mut Future<T>, waker: &Waker) -> Poll<Result<T, Waste>> {
        let mut cx = Context::from_waker(waker);
        std::future::Future::poll(Pin::new(f), &mut cx)
    }

    #[tokio::test]
    async fn test_wake() {
        let (f, p) = crate::pair();
        assert_eq!(None, p.set_value("test"));
        let res = f.await;
        assert_eq!(res, Ok("test"));
    }

    #[tokio::test]
    async fn test_drop_promise() {
        let (f, p) = crate::pair::<usize>();
        drop(p);
        let res = f.await;
        assert_eq!(res, Err(Waste));
    }

    #[tokio::test]
    async fn test_drop_future() {
        let (f, p) = crate::pair();
        drop(f);
        assert_eq!(Some("test"), p.set_value("test"));
    }

    #[tokio::test]
    async fn test_block_future() {
        let (f, p) = crate::pair();
        tokio::spawn(async move {
            let _ = p.set_value("test");
        });
        let res = f.await;
        assert_eq!(res, Ok("test"));
    }

    #[tokio::test]
    async fn test_assume_init() {
        let (f, p) = crate::pair();
        tokio::spawn(async move {
            let ptr = Promise::into_raw(p);
            unsafe {
                ptr.as_ptr().write(MaybeUninit::new("test"));
                Promise::from_raw(ptr).assume_value_set();
            }
        });
        let res = f.await;
        assert_eq!(res, Ok("test"));
    }

    /// A pending future is woken exactly once by `set_value`, and re-polling
    /// with the same waker (the waker-replacement path) keeps it pending.
    #[test]
    fn test_set_value_wakes_pending_future() {
        let (mut f, p) = crate::pair();
        let (counter, waker) = counting_waker();
        assert_eq!(poll_once(&mut f, &waker), Poll::Pending);
        assert_eq!(poll_once(&mut f, &waker), Poll::Pending);
        assert_eq!(None, p.set_value("test"));
        assert_eq!(counter.0.load(Ordering::SeqCst), 1);
        assert_eq!(poll_once(&mut f, &waker), Poll::Ready(Ok("test")));
        // The value can be taken only once.
        assert_eq!(poll_once(&mut f, &waker), Poll::Ready(Err(Waste)));
    }

    /// Dropping the promise while the future waits wakes it with `Err(Waste)`.
    #[test]
    fn test_drop_promise_wakes_pending_future() {
        let (mut f, p) = crate::pair::<usize>();
        let (counter, waker) = counting_waker();
        assert_eq!(poll_once(&mut f, &waker), Poll::Pending);
        drop(p);
        assert_eq!(counter.0.load(Ordering::SeqCst), 1);
        assert_eq!(poll_once(&mut f, &waker), Poll::Ready(Err(Waste)));
    }

    /// A value that was set but never read is dropped together with the future.
    #[test]
    fn test_unread_value_dropped_with_future() {
        let tracker = Arc::new(());
        let (f, p) = crate::pair();
        assert!(p.set_value(Arc::clone(&tracker)).is_none());
        drop(f);
        assert_eq!(Arc::strong_count(&tracker), 1);
    }
}