1use futures_util::lock::BiLock;
27use std::future::Future;
28use std::marker::Unpin;
29use std::pin::Pin;
30use std::task::Waker;
31use std::task::{Context, Poll};
32
/// Shared state behind the `BiLock` connecting a [`ManualFuture`] to its completer.
#[derive(Debug)]
enum State<T> {
    /// No value has been supplied and no waker has been registered yet.
    Incomplete,
    /// No value has been supplied; holds the waker from the most recent pending poll.
    Waiting(Waker),
    /// A value has been supplied. The `Option` becomes `None` once the future has taken it.
    Complete(Option<T>),
}

impl<T> State<T> {
    /// Builds the initial state: `Complete` when a value is given up front, otherwise
    /// `Incomplete`.
    fn new(value: Option<T>) -> Self {
        value.map_or(Self::Incomplete, |ready| Self::Complete(Some(ready)))
    }
}
48
49#[derive(Debug)]
54pub struct ManualFuture<T> {
55 state: BiLock<State<T>>,
56}
57
/// The resolving half of a [`ManualFuture`], returned by [`ManualFuture::new`].
///
/// Calling `complete` consumes this handle and hands the value to the paired future.
#[derive(Debug)]
pub struct ManualFutureCompleter<T> {
    // The other half of the lock created in `ManualFuture::new`.
    state: BiLock<State<T>>,
}
66
67impl<T: Unpin> ManualFutureCompleter<T> {
68 pub async fn complete(self, value: T) {
73 let mut state = self.state.lock().await;
74
75 match std::mem::replace(&mut *state, State::Complete(Some(value))) {
76 State::Incomplete => {}
77 State::Waiting(w) => w.wake(),
78 State::Complete(_) => unreachable!("future already completed"),
79 }
80 }
81}
82
83impl<T> ManualFuture<T> {
84 pub fn new() -> (Self, ManualFutureCompleter<T>) {
87 let (a, b) = BiLock::new(State::new(None));
88 (Self { state: a }, ManualFutureCompleter { state: b })
89 }
90
91 pub fn new_completed(value: T) -> Self {
97 let (state, _) = BiLock::new(State::new(Some(value)));
98 Self { state }
99 }
100}
101
102impl<T: Unpin> Future for ManualFuture<T> {
103 type Output = T;
104
105 fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
106 let mut state = match self.state.poll_lock(cx) {
107 Poll::Pending => return Poll::Pending,
108 Poll::Ready(v) => v,
109 };
110
111 match &mut *state {
112 s @ State::Incomplete => *s = State::Waiting(cx.waker().clone()),
113 State::Waiting(w) if w.will_wake(cx.waker()) => {}
114 s @ State::Waiting(_) => *s = State::Waiting(cx.waker().clone()),
115 State::Complete(v) => match v.take() {
116 Some(v) => return Poll::Ready(v),
117 None => panic!("future already polled to completion"),
118 },
119 }
120
121 Poll::Pending
122 }
123}
124
#[cfg(test)]
mod tests {
    use super::*;
    use futures_executor::block_on;
    use std::thread::sleep;
    use std::thread::spawn;
    use std::time::Duration;
    use tokio::time::timeout;

    // Tests use non-unit payloads so each assertion checks the delivered value.
    // Comparing `()` always succeeds and trips `clippy::unit_cmp`.

    #[tokio::test]
    async fn test_not_completed() {
        // The completer is dropped immediately, so the future must stay pending.
        let (future, _) = ManualFuture::<()>::new();
        timeout(Duration::from_millis(100), future)
            .await
            .expect_err("should not complete");
    }

    #[tokio::test]
    async fn test_manual_completed() {
        // The future is polled (and registers a waker) before the completer runs.
        let (future, completer) = ManualFuture::<i32>::new();
        let (value, ()) = tokio::join!(future, completer.complete(42));
        assert_eq!(value, 42);
    }

    #[tokio::test]
    async fn test_completed_before_poll() {
        // Completing before the first poll exercises the `Incomplete` -> `Complete` path.
        let (future, completer) = ManualFuture::<i32>::new();
        completer.complete(7).await;
        assert_eq!(future.await, 7);
    }

    #[tokio::test]
    async fn test_pre_completed() {
        assert_eq!(ManualFuture::new_completed(42).await, 42);
    }

    #[test]
    fn test_threaded() {
        let (future, completer) = ManualFuture::<i32>::new();

        let waiter = spawn(move || {
            assert_eq!(block_on(future), 42);
        });

        let resolver = spawn(move || {
            sleep(Duration::from_millis(100));
            block_on(async {
                completer.complete(42).await;
            });
        });

        waiter.join().unwrap();
        resolver.join().unwrap();
    }
}