1#![warn(clippy::pedantic)]
27
28use futures::lock::BiLock;
29use std::future::Future;
30use std::marker::Unpin;
31use std::pin::Pin;
32use std::task::Waker;
33use std::task::{Context, Poll};
34
/// Shared state between a [`ManualFuture`] and its completer.
enum State<T> {
    /// No value yet, and no task has polled the future.
    Incomplete,
    /// No value yet; holds the waker of the task that last polled the future.
    Waiting(Waker),
    /// A value was supplied. It becomes `None` once the future has yielded it.
    Complete(Option<T>),
}
40
41impl<T> State<T> {
42 fn new(value: Option<T>) -> Self {
43 match value {
44 None => Self::Incomplete,
45 v @ Some(_) => Self::Complete(v),
46 }
47 }
48}
49
50pub struct ManualFuture<T: Unpin> {
55 state: BiLock<State<T>>,
56}
57
58pub struct ManualFutureCompleter<T: Unpin> {
63 state: BiLock<State<T>>,
64}
65
66impl<T: Unpin> ManualFutureCompleter<T> {
67 pub async fn complete(self, value: T) {
72 let mut state = self.state.lock().await;
73
74 match std::mem::replace(&mut *state, State::Complete(Some(value))) {
75 State::Incomplete => {}
76 State::Waiting(w) => w.wake(),
77 _ => panic!("future already completed"),
78 }
79 }
80}
81
82impl<T: Unpin> ManualFuture<T> {
83 pub fn new() -> (Self, ManualFutureCompleter<T>) {
86 let (a, b) = BiLock::new(State::new(None));
87 (Self { state: a }, ManualFutureCompleter { state: b })
88 }
89
90 pub fn new_completed(value: T) -> Self {
96 let (state, _) = BiLock::new(State::new(Some(value)));
97 Self { state }
98 }
99}
100
101impl<T: Unpin> Future for ManualFuture<T> {
102 type Output = T;
103
104 fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
105 let mut state = match self.state.poll_lock(cx) {
106 Poll::Pending => return Poll::Pending,
107 Poll::Ready(v) => v,
108 };
109
110 match &mut *state {
111 s @ State::Incomplete => *s = State::Waiting(cx.waker().clone()),
112 State::Waiting(w) if w.will_wake(cx.waker()) => {}
113 s @ State::Waiting(_) => *s = State::Waiting(cx.waker().clone()),
114 State::Complete(v) => match v.take() {
115 Some(v) => return Poll::Ready(v),
116 None => panic!("future already polled to completion"),
117 },
118 }
119
120 Poll::Pending
121 }
122}
123
#[cfg(test)]
mod tests {
    use super::*;
    use futures::executor::block_on;
    use futures::future::join;
    use std::thread::{sleep, spawn};
    use std::time::Duration;
    use tokio::time::timeout;

    // Non-unit payloads are used throughout. They verify that the completed
    // value actually reaches the awaiting side, and they avoid
    // `clippy::unit_cmp`, which rejects `assert_eq!` on `()`.

    #[tokio::test]
    async fn test_not_completed() {
        // The completer is dropped immediately, so the future must never resolve.
        let (future, _) = ManualFuture::<()>::new();
        timeout(Duration::from_millis(100), future)
            .await
            .expect_err("should not complete");
    }

    #[tokio::test]
    async fn test_manual_completed() {
        let (future, completer) = ManualFuture::<u32>::new();
        let (value, ()) = join(future, completer.complete(42)).await;
        assert_eq!(value, 42);
    }

    #[tokio::test]
    async fn test_completed_before_poll() {
        // Exercises the Incomplete -> Complete transition with no waiter registered.
        let (future, completer) = ManualFuture::<u32>::new();
        completer.complete(5).await;
        assert_eq!(future.await, 5);
    }

    #[tokio::test]
    async fn test_pre_completed() {
        assert_eq!(ManualFuture::new_completed(7_u32).await, 7);
    }

    #[test]
    fn test_threaded() {
        let (future, completer) = ManualFuture::<u32>::new();

        let waiter = spawn(move || {
            assert_eq!(block_on(future), 99);
        });

        let setter = spawn(move || {
            sleep(Duration::from_millis(100));
            block_on(completer.complete(99));
        });

        waiter.join().unwrap();
        setter.join().unwrap();
    }
}