use futures_util::lock::BiLock;
use std::future::Future;
use std::marker::Unpin;
use std::pin::Pin;
use std::task::Waker;
use std::task::{Context, Poll};
/// Shared state behind the `BiLock` that is split between a `ManualFuture`
/// and its `ManualFutureCompleter`.
#[derive(Debug)]
enum State<T> {
/// No value has been supplied yet and no task has polled the future.
Incomplete,
/// No value has been supplied yet; holds the waker of the most recent task
/// that polled the future.
Waiting(Waker),
/// A value has been supplied. The `Option` is emptied (`None`) once `poll`
/// has handed the value out.
Complete(Option<T>),
}
impl<T> State<T> {
    /// Builds the initial state: `Complete` when a value is already known,
    /// `Incomplete` otherwise.
    fn new(value: Option<T>) -> Self {
        value.map_or(Self::Incomplete, |known| Self::Complete(Some(known)))
    }
}
/// A future whose output is supplied manually through a
/// [`ManualFutureCompleter`].
///
/// Created with [`ManualFuture::new`] (paired with a completer) or
/// [`ManualFuture::new_completed`] (already holding its value). If the
/// completer is dropped without calling `complete`, the future stays pending
/// forever.
#[must_use = "futures do nothing unless you `.await` or poll them"]
#[derive(Debug)]
pub struct ManualFuture<T> {
    /// This side's half of the lock shared with the completer.
    state: BiLock<State<T>>,
}
/// One-shot handle that supplies the value for its paired `ManualFuture`.
///
/// Created together with the future by `ManualFuture::new`. Calling
/// `complete` consumes it, so the future can be completed at most once.
#[derive(Debug)]
pub struct ManualFutureCompleter<T> {
/// This side's half of the lock shared with the paired future.
state: BiLock<State<T>>,
}
impl<T: Unpin> ManualFutureCompleter<T> {
    /// Completes the paired [`ManualFuture`] with `value` and wakes the task
    /// currently waiting on it, if any.
    ///
    /// Consumes the completer, so a future can be completed at most once.
    ///
    /// # Panics
    ///
    /// Panics if the shared state is already `Complete`. This cannot happen
    /// through the public API: `new_completed` hands out no completer, and the
    /// single completer from `new` is consumed here.
    pub async fn complete(self, value: T) {
        // Swap the value in inside a scope so the lock guard is released
        // *before* waking. Otherwise the woken task may be polled on another
        // thread while the lock is still held, find it busy, and need a second
        // wake-up from the BiLock before it can observe the value.
        let previous = {
            let mut state = self.state.lock().await;
            std::mem::replace(&mut *state, State::Complete(Some(value)))
        };
        match previous {
            State::Incomplete => {}
            State::Waiting(waker) => waker.wake(),
            State::Complete(_) => unreachable!("future already completed"),
        }
    }
}
impl<T> ManualFuture<T> {
    /// Creates a pending future together with the completer that resolves it.
    pub fn new() -> (Self, ManualFutureCompleter<T>) {
        let (future_half, completer_half) = BiLock::new(State::new(None));
        let completer = ManualFutureCompleter {
            state: completer_half,
        };
        (Self { state: future_half }, completer)
    }

    /// Creates a future that already holds `value` and yields it when polled.
    ///
    /// No completer is returned; the second lock half is discarded.
    pub fn new_completed(value: T) -> Self {
        let initial = State::new(Some(value));
        let (state, _) = BiLock::new(initial);
        Self { state }
    }
}
/// Polling registers the current task's waker until the completer supplies a
/// value, then yields that value exactly once.
impl<T: Unpin> Future for ManualFuture<T> {
    type Output = T;

    /// # Panics
    ///
    /// Panics if polled again after it has already returned `Poll::Ready`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        // Lock busy: report Pending, relying on `poll_lock` to register `cx`
        // for a wake-up when the lock frees (futures-util BiLock contract).
        let mut guard = match self.state.poll_lock(cx) {
            Poll::Ready(guard) => guard,
            Poll::Pending => return Poll::Pending,
        };

        if let State::Complete(slot) = &mut *guard {
            return match slot.take() {
                Some(value) => Poll::Ready(value),
                None => panic!("future already polled to completion"),
            };
        }

        // Still incomplete: (re)register this task unless the stored waker
        // would already wake it.
        let up_to_date =
            matches!(&*guard, State::Waiting(stored) if stored.will_wake(cx.waker()));
        if !up_to_date {
            *guard = State::Waiting(cx.waker().clone());
        }
        Poll::Pending
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use futures_executor::block_on;
    use std::thread::sleep;
    use std::thread::spawn;
    use std::time::Duration;
    use tokio::time::timeout;

    /// With the completer dropped unused, the future must never resolve.
    #[tokio::test]
    async fn test_not_completed() {
        let (future, _) = ManualFuture::<u32>::new();
        timeout(Duration::from_millis(100), future)
            .await
            .expect_err("should not complete");
    }

    /// The value passed to `complete` is exactly what the future yields.
    #[tokio::test]
    async fn test_manual_completed() {
        let (future, completer) = ManualFuture::<u32>::new();
        let (value, ()) = tokio::join!(future, completer.complete(42));
        assert_eq!(value, 42);
    }

    /// A pre-completed future resolves to its stored value.
    #[tokio::test]
    async fn test_pre_completed() {
        assert_eq!(ManualFuture::new_completed(7u32).await, 7);
    }

    /// Polling again after `Ready` hits the documented panic.
    #[test]
    #[should_panic(expected = "future already polled to completion")]
    fn test_poll_after_completion_panics() {
        let mut future = ManualFuture::new_completed(1u32);
        let mut cx = Context::from_waker(futures_util::task::noop_waker_ref());
        assert_eq!(Pin::new(&mut future).poll(&mut cx), Poll::Ready(1));
        let _ = Pin::new(&mut future).poll(&mut cx);
    }

    /// Completion on one thread wakes a future blocked on another thread.
    #[test]
    fn test_threaded() {
        let (future, completer) = ManualFuture::<u32>::new();
        let waiter = spawn(move || block_on(future));
        let completer_thread = spawn(move || {
            // Delay so the waiter has very likely registered its waker first,
            // exercising the `Waiting` path.
            sleep(Duration::from_millis(100));
            block_on(completer.complete(42));
        });
        assert_eq!(waiter.join().unwrap(), 42);
        completer_thread.join().unwrap();
    }
}