use std::cell::UnsafeCell;
use std::fmt::Debug;
use std::mem::{ManuallyDrop, MaybeUninit};
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicU8, Ordering};
use std::task::{Context, Poll};
pub mod sync;
/// Lifecycle of a continuation, stored in `Shared::state` as its `u8` discriminant.
#[repr(u8)]
enum State {
    /// No value has been published and the future is still alive.
    Empty = 0,
    /// The sender has written and published a value that has not been taken yet.
    Data = 1,
    /// The future has taken the value out of the slot.
    Gone = 2,
    /// The future was dropped; a sender that sees this drops its own value.
    FutureHangup = 3,
}
/// State shared (via `Arc`) between a [`Sender`] and its paired [`Future`].
#[derive(Debug)]
struct Shared<R> {
    /// Slot for the sent value. `Sender::send` writes it before trying to
    /// publish `State::Data`. The value is later moved out by the future (when
    /// polled, or dropped in `drop_impl`), or by the sender itself if the
    /// future had already hung up.
    data: UnsafeCell<MaybeUninit<R>>,
    /// Current `State`, stored as its `u8` discriminant.
    state: AtomicU8,
    /// Waker registered by `poll`; woken by the sender after publishing a value.
    waker: atomic_waker::AtomicWaker,
}
/// Sending half of a continuation; delivers one value with [`Sender::send`].
///
/// Dropping a sender without calling `send` panics, unless the thread is
/// already panicking.
#[derive(Debug)]
pub struct Sender<R> {
    shared: Arc<Shared<R>>,
    /// Set by `send`; checked by the `Drop` impl.
    sent: bool,
}
/// A [`Future`] wrapper that calls [`FutureCancellation::cancel`] if it is
/// dropped before the sender has published a value.
#[derive(Debug)]
pub struct FutureCancel<R, C: FutureCancellation> {
    /// `ManuallyDrop` because this type's `Drop` impl settles the shared state
    /// itself and must not also run `Future`'s own drop logic.
    future: ManuallyDrop<Future<R>>,
    /// Invoked from `Drop` when no value had been sent yet.
    cancellation: C,
}
/// Hook used by [`FutureCancel`] to report that it was dropped before a value
/// was sent.
pub trait FutureCancellation {
    /// Called from `FutureCancel`'s `Drop` impl when the shared state was still empty.
    fn cancel(&mut self);
}
pub fn continuation<R>() -> (Sender<R>, Future<R>) {
let shared = Arc::new(Shared {
data: UnsafeCell::new(MaybeUninit::uninit()),
state: AtomicU8::new(State::Empty as u8),
waker: atomic_waker::AtomicWaker::new(),
});
(
Sender {
shared: shared.clone(),
sent: false,
},
Future { shared },
)
}
pub fn continuation_cancel<R, C: FutureCancellation>(
cancellation: C,
) -> (Sender<R>, FutureCancel<R, C>) {
let shared = Arc::new(Shared {
data: UnsafeCell::new(MaybeUninit::uninit()),
state: AtomicU8::new(State::Empty as u8),
waker: atomic_waker::AtomicWaker::new(),
});
(
Sender {
shared: shared.clone(),
sent: false,
},
FutureCancel {
future: ManuallyDrop::new(Future { shared }),
cancellation,
},
)
}
impl<R> Sender<R> {
    /// Delivers `data` to the paired future and wakes it.
    ///
    /// If the future has already been dropped, `data` is dropped here instead.
    ///
    /// # Panics
    ///
    /// Panics if the shared state shows the continuation was already resumed.
    pub fn send(mut self, data: R) {
        self.sent = true;
        let slot = self.shared.data.get();
        // SAFETY: only the sender writes the slot, and the future does not read
        // it until it observes `State::Data`, which is published below with `Release`.
        unsafe { (*slot).as_mut_ptr().write(data) };
        loop {
            let observed = match self.shared.state.compare_exchange_weak(
                State::Empty as u8,
                State::Data as u8,
                Ordering::Release,
                Ordering::Relaxed,
            ) {
                Ok(_) => {
                    // Published: let a pending poll pick the value up.
                    self.shared.waker.wake();
                    return;
                }
                Err(observed) => observed,
            };
            if observed == State::Empty as u8 {
                // Spurious failure of the weak compare-exchange; retry.
            } else if observed == State::Data as u8 || observed == State::Gone as u8 {
                unreachable!("Continuation already resumed")
            } else if observed == State::FutureHangup as u8 {
                // The future has hung up and will never read the slot, so the
                // sender reclaims and drops its own value.
                // SAFETY: the slot was initialized above by this sender.
                drop(unsafe { (*slot).assume_init_read() });
                return;
            } else {
                unreachable!("Invalid state")
            }
        }
    }

    /// Returns `true` once the paired future has been dropped.
    pub fn is_cancelled(&self) -> bool {
        let current = self.shared.state.load(Ordering::Relaxed);
        current == State::FutureHangup as u8
    }
}
impl<R> Drop for Sender<R> {
    /// Enforces that a sender delivers its value before going away.
    ///
    /// The check is skipped while the thread is already panicking, since a
    /// second panic during unwinding would abort the process.
    fn drop(&mut self) {
        let already_unwinding = std::thread::panicking();
        assert!(
            already_unwinding || self.sent,
            "Sender dropped without sending data"
        );
    }
}
/// Receiving half of a continuation; resolves to the value passed to
/// [`Sender::send`].
///
/// Polling again after it has returned `Poll::Ready` panics
/// ("Continuation already polled").
#[derive(Debug)]
pub struct Future<R> {
    shared: Arc<Shared<R>>,
}
/// Outcome of `Future::drop_impl`.
enum DropState {
    /// No value had been published yet; the sender will now observe the hang-up.
    Cancelled,
    /// A value had been published (and was dropped) or was already taken.
    NotCancelled,
}
impl<R> Future<R> {
    /// Marks the future as hung up and settles whatever the sender left behind.
    ///
    /// Returns `DropState::Cancelled` if no value had been published yet. Returns
    /// `DropState::NotCancelled` if a value was published (it is dropped here)
    /// or had already been taken by `poll`.
    ///
    /// Must run at most once per shared state: a second call would observe
    /// `FutureHangup` and panic.
    fn drop_impl(&mut self) -> DropState {
        let previous = self
            .shared
            .state
            .swap(State::FutureHangup as u8, Ordering::Acquire);
        if previous == State::Empty as u8 {
            // The sender has not published; it will see the hang-up and drop its own value.
            return DropState::Cancelled;
        }
        if previous == State::Data as u8 {
            // SAFETY: the `Acquire` swap synchronizes with the sender's `Release`
            // publish, so the slot is initialized; after this swap no one else reads it.
            let unread = unsafe { (*self.shared.data.get()).assume_init_read() };
            drop(unread);
            return DropState::NotCancelled;
        }
        if previous == State::Gone as u8 {
            return DropState::NotCancelled;
        }
        unreachable!("Invalid state")
    }
}
impl<R> Drop for Future<R> {
fn drop(&mut self) {
self.drop_impl();
}
}
impl<R, C: FutureCancellation> Drop for FutureCancel<R, C> {
    /// Hangs up the wrapped future, releases its share of the state, and calls
    /// the cancellation hook if no value had been sent yet.
    fn drop(&mut self) {
        // Settle the shared slot exactly once. `Future`'s own `Drop` must not
        // run afterwards: its `drop_impl` would observe `FutureHangup` and panic.
        let outcome = self.future.drop_impl();
        // Release only the `Arc` rather than `mem::forget`-ing the whole future.
        // Forgetting it would leak the shared allocation (and any registered
        // waker) every time a `FutureCancel` is dropped.
        // SAFETY: `self.future` is `ManuallyDrop`, so it is never dropped
        // automatically, and it is not accessed after this point; dropping its
        // `shared` field releases the `Arc` exactly once without running `Future::drop`.
        unsafe { std::ptr::drop_in_place(&mut self.future.shared) };
        // Run the user hook last, so a panicking `cancel` cannot leak the `Arc`.
        if let DropState::Cancelled = outcome {
            self.cancellation.cancel();
        }
    }
}
/// Result of one attempt to take the value out of the shared slot.
enum ReadStatus<R> {
    /// The value was taken.
    Data(R),
    /// No value has been published yet.
    Waiting,
    /// A value is present but the weak compare-exchange failed spuriously; retry.
    Spurious,
}
impl<R> Future<R> {
    /// Converts the outcome of a `Data -> Gone` compare-exchange into a `ReadStatus`.
    ///
    /// On success, the value is moved out of `data`. On failure, the observed
    /// state decides the result:
    /// - `Empty` means nothing has been sent yet.
    /// - `Data` means the weak compare-exchange failed spuriously.
    ///
    /// # Panics
    ///
    /// Panics if the value was already taken (`Gone`) or the state is unexpected.
    fn interpret_result(
        result: Result<u8, u8>,
        data: &UnsafeCell<MaybeUninit<R>>,
    ) -> ReadStatus<R> {
        let observed = match result {
            // SAFETY: a successful `Data -> Gone` exchange means the slot holds a
            // published value and this caller is now its sole owner.
            Ok(_) => return ReadStatus::Data(unsafe { (*data.get()).assume_init_read() }),
            Err(observed) => observed,
        };
        if observed == State::Empty as u8 {
            ReadStatus::Waiting
        } else if observed == State::Data as u8 {
            ReadStatus::Spurious
        } else if observed == State::Gone as u8 {
            panic!("Continuation already polled")
        } else {
            unreachable!("Invalid state")
        }
    }
}
impl<R> std::future::Future for Future<R> {
type Output = R;
fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
let state = self.shared.state.compare_exchange_weak(
State::Data as u8,
State::Gone as u8,
Ordering::Acquire,
Ordering::Relaxed,
);
match Self::interpret_result(state, &self.shared.data) {
ReadStatus::Data(data) => return Poll::Ready(data),
ReadStatus::Waiting | ReadStatus::Spurious => {}
}
self.shared.waker.register(cx.waker());
loop {
let state2 = self.shared.state.compare_exchange_weak(
State::Data as u8,
State::Gone as u8,
Ordering::Acquire,
Ordering::Relaxed,
);
match Self::interpret_result(state2, &self.shared.data) {
ReadStatus::Data(data) => return Poll::Ready(data),
ReadStatus::Waiting => return Poll::Pending,
ReadStatus::Spurious => continue,
}
}
}
}
impl<R, C: FutureCancellation> std::future::Future for FutureCancel<R, C> {
    type Output = R;
    /// Forwards to the wrapped [`Future`].
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // SAFETY: nothing is moved out of `this`. The only field touched is
        // `future`, which is re-pinned with `Pin::new`; the compiler checks that
        // this is allowed (`Future<R>` holds only an `Arc` and is `Unpin`).
        let this = unsafe { self.get_unchecked_mut() };
        let inner: &mut Future<R> = &mut this.future;
        std::future::Future::poll(Pin::new(inner), cx)
    }
}
// SAFETY (all three impls): `Shared<R>` contains an `UnsafeCell`, so it is not
// `Sync` and `Arc<Shared<R>>` is not automatically `Send`. Access to the cell
// is coordinated through `Shared::state`:
// - the sender writes the slot only before publishing `State::Data` with `Release`;
// - the future reads it only after observing `State::Data` with `Acquire`;
// - otherwise the sender reads back its own write after seeing `State::FutureHangup`.
// `R: Send` is required because the value may be produced on one thread and
// consumed on another.
unsafe impl<R: Send> Send for Future<R> {}
// `C` is owned by the wrapper and is invoked from `Drop` on whichever thread
// drops it, so it must be `Send` as well.
unsafe impl<R: Send, C: Send + FutureCancellation> Send for FutureCancel<R, C> {}
unsafe impl<R: Send> Send for Sender<R> {}
#[cfg(test)]
mod test {
    use crate::continuation;
    use std::pin::Pin;
    use std::task::Poll;

    #[cfg(target_arch = "wasm32")]
    wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_browser);

    /// A poll before `send` is pending; the poll after `send` yields the value.
    #[cfg_attr(not(target_arch = "wasm32"), test)]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    fn test_continue() {
        let (sender, mut future) = continuation();
        let mut pinned = Pin::new(&mut future);
        assert!(test_executors::poll_once(pinned.as_mut()).is_pending());
        sender.send(23);
        match test_executors::poll_once(pinned) {
            Poll::Ready(23) => {}
            other => panic!("Unexpected result {:?}", other),
        }
    }

    /// Compile-time check that both halves can be moved across threads.
    #[test]
    fn test_is_send() {
        fn assert_send<T: Send>() {}
        assert_send::<crate::Future<i32>>();
        assert_send::<crate::Sender<i32>>();
    }

    /// Sends many values from this task while another thread awaits them in order.
    #[test_executors::async_test]
    async fn test_stress() {
        #[cfg(not(target_arch = "wasm32"))]
        use std::thread;
        #[cfg(target_arch = "wasm32")]
        use wasm_thread as thread;

        // The fields are only read through the derived `Debug` impl, which the
        // dead-code lint does not count as a use.
        #[allow(dead_code)]
        #[derive(Debug)]
        struct Ex {
            a: u64,
            b: u64,
        }
        impl Ex {
            fn new() -> Ex {
                Ex { a: 0, b: 0 }
            }
        }

        const COUNT: usize = 1000;
        let mut senders = Vec::with_capacity(COUNT);
        let mut futs = Vec::with_capacity(COUNT);
        for _ in 0..COUNT {
            let (sender, fut) = continuation();
            senders.push(sender);
            futs.push(fut);
        }

        let (overall_send, overall_fut) = continuation::<()>();
        thread::spawn(move || {
            test_executors::sleep_on(async move {
                for fut in futs {
                    let r = fut.await;
                    println!("{:?}", r);
                }
                overall_send.send(());
            });
        });
        for sender in senders {
            sender.send(Ex::new());
        }
        overall_fut.await;
    }
}