1use std::future::Future;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4
/// A heap-allocated, pinned, type-erased future that is `Send` and may borrow
/// data for the lifetime `'a`.
type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + 'a + Send>>;
6
7pub struct Promise<T, E> {
8 future: BoxFuture<'static, Result<T, E>>,
9}
10
11impl<T: 'static + Send, E: 'static + Send> Promise<T, E> {
12 pub fn new<F>(future: F) -> Self
13 where
14 F: Future<Output = Result<T, E>> + Send + 'static,
15 {
16 Self {
17 future: Box::pin(future),
18 }
19 }
20
21 pub fn resolve(value: T) -> Self {
22 Self::new(async move { Ok(value) })
23 }
24
25 pub fn reject(error: E) -> Self {
26 Self::new(async move { Err(error) })
27 }
28
29 pub fn all<I>(promises: I) -> Promise<Vec<T>, E>
30 where
31 I: IntoIterator<Item = Promise<T, E>>,
32 {
33 let promises = promises.into_iter().collect::<Vec<_>>();
34 let future = async move {
35 let mut result = Vec::with_capacity(promises.len());
36 for promise in promises {
37 result.push(promise.future.await?);
38 }
39 Ok(result)
40 };
41 Promise::new(future)
42 }
43
44 pub fn any<I>(promises: I) -> Promise<T, E>
45 where
46 I: IntoIterator<Item = Promise<T, E>>,
47 {
48 let mut promises = promises
49 .into_iter()
50 .map(|p| Box::pin(p.future))
51 .collect::<Vec<_>>();
52 let future = async move {
53 loop {
54 for future in promises.iter_mut() {
55 let polled = std::future::poll_fn(|cx| match future.as_mut().poll(cx) {
56 Poll::Ready(result) => Poll::Ready(Some(result)),
57 Poll::Pending => Poll::Pending,
58 })
59 .await;
60 if let Some(result) = polled {
61 return result;
62 }
63 }
64 }
65 };
66 Promise::new(future)
67 }
68
69 pub fn then<T2, Fut>(self, f: impl FnOnce(T) -> Fut + Send + 'static) -> Promise<T2, E>
70 where
71 Fut: Future<Output = Result<T2, E>> + Send + 'static,
72 T2: 'static + Send,
73 {
74 let future = async move {
75 match self.future.await {
76 Ok(val) => f(val).await,
77 Err(err) => Err(err),
78 }
79 };
80 Promise::new(future)
81 }
82
83 pub fn catch<Fut>(self, f: impl FnOnce(E) -> Fut + Send + 'static) -> Promise<T, E>
84 where
85 Fut: Future<Output = Result<T, E>> + Send + 'static,
86 {
87 let future = async move {
88 match self.future.await {
89 Ok(val) => Ok(val),
90 Err(err) => f(err).await,
91 }
92 };
93 Promise::new(future)
94 }
95
96 pub fn transform<T2, E2, Fut>(
97 self,
98 f: impl FnOnce(Result<T, E>) -> Fut + Send + 'static,
99 ) -> Promise<T2, E2>
100 where
101 Fut: Future<Output = Result<T2, E2>> + Send + 'static,
102 T2: 'static + Send,
103 E2: 'static + Send,
104 {
105 let future = async move {
106 let result = self.await;
107 f(result).await
108 };
109 Promise::new(future)
110 }
111
112 pub fn finally<Fut>(self, f: impl FnOnce() -> Fut + Send + 'static) -> Promise<T, E>
113 where
114 Fut: Future<Output = ()> + Send + 'static,
115 {
116 let future = async move {
117 let result = self.await;
118 f().await;
119 result
120 };
121 Promise::new(future)
122 }
123}
124
125impl<T, E> Future for Promise<T, E> {
126 type Output = Result<T, E>;
127
128 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
129 self.future.as_mut().poll(cx)
130 }
131}
132
133#[cfg(test)]
134mod tests {
135 use super::*;
136 use std::{
137 future::poll_fn,
138 sync::{Arc, Mutex},
139 time::{Duration, Instant},
140 };
141
142 async fn delay(duration: Duration) {
143 let timer = Instant::now();
144 poll_fn(|cx| {
145 if timer.elapsed() >= duration {
146 cx.waker().wake_by_ref();
147 Poll::Ready(())
148 } else {
149 cx.waker().wake_by_ref();
150 Poll::Pending
151 }
152 })
153 .await;
154 }
155
156 #[pollster::test]
157 async fn test_promise_resolve() {
158 let promise = Promise::<i32, &str>::resolve(42);
159 assert_eq!(promise.await.unwrap(), 42);
160 }
161
162 #[pollster::test]
163 async fn test_promise_reject() {
164 let promise = Promise::<i32, &str>::reject("error");
165 assert_eq!(promise.await.unwrap_err(), "error");
166 }
167
168 #[pollster::test]
169 async fn test_promise_then() {
170 let promise = Promise::<i32, &str>::resolve(2).then(|x| async move { Ok(x * 3) });
171 assert_eq!(promise.await.unwrap(), 6);
172 }
173
174 #[pollster::test]
175 async fn test_promise_catch() {
176 let promise = Promise::<i32, &str>::reject("error").catch(|_| async { Ok(99) });
177 assert_eq!(promise.await.unwrap(), 99);
178 }
179
180 #[pollster::test]
181 async fn test_promise_finally() {
182 let flag = Arc::new(Mutex::new(false));
183 let flag_clone = Arc::clone(&flag);
184 let promise = Promise::<i32, &str>::resolve(5).finally(move || async move {
185 let mut flag = flag_clone.lock().unwrap();
186 *flag = true;
187 });
188 assert_eq!(promise.await.unwrap(), 5);
189 assert!(*flag.lock().unwrap());
190 }
191
192 #[pollster::test]
193 async fn test_promise_transform() {
194 let promise = Promise::<i32, &str>::resolve(10)
195 .transform(|result| async move { result.map(|v| v * 2) });
196 assert_eq!(promise.await.unwrap(), 20);
197 }
198
199 #[pollster::test]
200 async fn test_promise_all() {
201 let promise = Promise::all([
202 Promise::<i32, &str>::resolve(1),
203 Promise::<i32, &str>::resolve(2),
204 Promise::<i32, &str>::resolve(3),
205 ]);
206 assert_eq!(promise.await.unwrap(), vec![1, 2, 3]);
207 }
208
209 #[pollster::test]
210 async fn test_promise_any() {
211 let promise = Promise::any([
212 Promise::<i32, &str>::new(async {
213 delay(Duration::from_millis(10)).await;
214 Ok(1)
215 }),
216 Promise::<i32, &str>::new(async {
217 delay(Duration::from_millis(0)).await;
218 Ok(2)
219 }),
220 Promise::<i32, &str>::new(async {
221 delay(Duration::from_millis(30)).await;
222 Ok(3)
223 }),
224 ]);
225 assert_eq!(promise.await.unwrap(), 1);
226 }
227
228 #[pollster::test]
229 async fn test_promise_chain() {
230 let promise = Promise::<i32, String>::new(async {
231 delay(Duration::from_millis(10)).await;
232 Ok(1)
233 })
234 .then(|x| async move { Ok(x + 3) })
235 .transform(|result| async move { result.map(|v| v * 2) })
236 .catch(|err| async move {
237 println!("Caught an error: {}", err);
238 Ok(0)
239 })
240 .finally(|| async { println!("Done") });
241 assert_eq!(promise.await.unwrap(), 8);
242 }
243}