anput_promise/
lib.rs

1use std::future::Future;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4
/// Heap-allocated, type-erased future that may be sent between threads.
type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;

/// A type-erased asynchronous computation that settles with `Result<T, E>`,
/// offering JavaScript-style combinators (`then`, `catch`, `finally`, `all`, `any`).
///
/// Like every Rust future a `Promise` is lazy: the wrapped work only makes
/// progress while this promise (or one built from it) is awaited or polled.
#[must_use = "promises do nothing unless awaited or polled"]
pub struct Promise<T, E> {
    future: BoxFuture<'static, Result<T, E>>,
}

impl<T: 'static + Send, E: 'static + Send> Promise<T, E> {
    /// Wraps an arbitrary `Send + 'static` future that yields `Result<T, E>`.
    pub fn new<F>(future: F) -> Self
    where
        F: Future<Output = Result<T, E>> + Send + 'static,
    {
        Self {
            future: Box::pin(future),
        }
    }

    /// Creates a promise that is already fulfilled with `value`.
    pub fn resolve(value: T) -> Self {
        Self::new(async move { Ok(value) })
    }

    /// Creates a promise that is already rejected with `error`.
    pub fn reject(error: E) -> Self {
        Self::new(async move { Err(error) })
    }

    /// Drives every promise concurrently and fulfills with their values in input order.
    ///
    /// All promises are polled on the task awaiting the returned promise, so a slow
    /// promise does not delay the start of the ones after it. An empty input fulfills
    /// immediately with an empty `Vec`.
    ///
    /// # Errors
    ///
    /// Rejects with the first error observed (promises are polled in input order on
    /// every wake-up). The remaining promises are not polled again; they are dropped
    /// together with the returned promise.
    pub fn all<I>(promises: I) -> Promise<Vec<T>, E>
    where
        I: IntoIterator<Item = Promise<T, E>>,
    {
        // `Some(future)` while still running, `None` once its value has been stored.
        let mut running: Vec<Option<BoxFuture<'static, Result<T, E>>>> = promises
            .into_iter()
            .map(|promise| Some(promise.future))
            .collect();
        let mut values: Vec<Option<T>> = running.iter().map(|_| None).collect();

        Promise::new(std::future::poll_fn(move |cx| {
            let mut still_running = false;
            for (slot, value) in running.iter_mut().zip(values.iter_mut()) {
                let outcome = match slot.as_mut() {
                    Some(future) => future.as_mut().poll(cx),
                    // Already fulfilled on an earlier wake-up.
                    None => continue,
                };
                match outcome {
                    Poll::Ready(Ok(v)) => {
                        *value = Some(v);
                        // Drop the finished future so it is never polled again.
                        *slot = None;
                    }
                    Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
                    Poll::Pending => still_running = true,
                }
            }
            if still_running {
                Poll::Pending
            } else {
                Poll::Ready(Ok(values
                    .iter_mut()
                    .map(|value| value.take().expect("Promise::all polled after completion"))
                    .collect::<Vec<T>>()))
            }
        }))
    }

    /// Settles with the outcome of whichever promise settles first, fulfilled *or*
    /// rejected.
    ///
    /// Note: despite the name this has race semantics (like JavaScript's
    /// `Promise.race`): unlike JavaScript's `Promise.any`, an early rejection wins.
    ///
    /// All promises are polled concurrently on the awaiting task. When several are
    /// ready during the same poll, the earliest in input order wins. The losers are not
    /// polled again and are dropped together with the returned promise. An empty input
    /// never settles (mirroring `Promise.race([])`).
    pub fn any<I>(promises: I) -> Promise<T, E>
    where
        I: IntoIterator<Item = Promise<T, E>>,
    {
        let mut contenders: Vec<BoxFuture<'static, Result<T, E>>> =
            promises.into_iter().map(|promise| promise.future).collect();

        Promise::new(std::future::poll_fn(move |cx| {
            // Poll every contender on each wake-up. Each pending one registers `cx`'s
            // waker, so whichever makes progress first wakes this task, and no contender
            // is starved behind a slow predecessor.
            for future in contenders.iter_mut() {
                if let Poll::Ready(outcome) = future.as_mut().poll(cx) {
                    return Poll::Ready(outcome);
                }
            }
            Poll::Pending
        }))
    }

    /// Chains `f` onto the fulfilled value.
    ///
    /// A rejection is passed through untouched and `f` is never called.
    pub fn then<T2, Fut>(self, f: impl FnOnce(T) -> Fut + Send + 'static) -> Promise<T2, E>
    where
        Fut: Future<Output = Result<T2, E>> + Send + 'static,
        T2: 'static + Send,
    {
        Promise::new(async move {
            match self.await {
                Ok(val) => f(val).await,
                Err(err) => Err(err),
            }
        })
    }

    /// Recovers from a rejection by running `f` on the error.
    ///
    /// A fulfilled value is passed through untouched and `f` is never called.
    pub fn catch<Fut>(self, f: impl FnOnce(E) -> Fut + Send + 'static) -> Promise<T, E>
    where
        Fut: Future<Output = Result<T, E>> + Send + 'static,
    {
        Promise::new(async move {
            match self.await {
                Ok(val) => Ok(val),
                Err(err) => f(err).await,
            }
        })
    }

    /// Maps the settled `Result` (fulfilled or rejected) through `f`.
    ///
    /// `f` may change both the value and the error type.
    pub fn transform<T2, E2, Fut>(
        self,
        f: impl FnOnce(Result<T, E>) -> Fut + Send + 'static,
    ) -> Promise<T2, E2>
    where
        Fut: Future<Output = Result<T2, E2>> + Send + 'static,
        T2: 'static + Send,
        E2: 'static + Send,
    {
        Promise::new(async move {
            let result = self.await;
            f(result).await
        })
    }

    /// Runs `f` once this promise settles, regardless of outcome, then settles with
    /// the original result.
    ///
    /// `f`'s future is awaited to completion before the result is delivered.
    pub fn finally<Fut>(self, f: impl FnOnce() -> Fut + Send + 'static) -> Promise<T, E>
    where
        Fut: Future<Output = ()> + Send + 'static,
    {
        Promise::new(async move {
            let result = self.await;
            f().await;
            result
        })
    }
}

/// Awaiting a promise polls the wrapped boxed future.
impl<T, E> Future for Promise<T, E> {
    type Output = Result<T, E>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // `Promise` is `Unpin` (its only field is a `Pin<Box<_>>`), so the pinned
        // reference can be dereferenced mutably.
        self.future.as_mut().poll(cx)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        future::{pending, poll_fn},
        sync::{Arc, Mutex},
        task::{Wake, Waker},
        thread::{self, Thread},
        time::{Duration, Instant},
    };

    /// Wakes the test thread that is parked inside `block_on`.
    struct ThreadWaker(Thread);

    impl Wake for ThreadWaker {
        fn wake(self: Arc<Self>) {
            self.0.unpark();
        }
    }

    /// Minimal std-only executor: polls `future` on the current thread and parks
    /// between polls until its waker fires.
    fn block_on<F: Future>(future: F) -> F::Output {
        let mut future = Box::pin(future);
        let waker = Waker::from(Arc::new(ThreadWaker(thread::current())));
        let mut cx = Context::from_waker(&waker);
        loop {
            match future.as_mut().poll(&mut cx) {
                Poll::Ready(output) => return output,
                Poll::Pending => thread::park(),
            }
        }
    }

    /// Completes once `duration` has elapsed.
    ///
    /// It re-wakes itself on every pending poll (a busy wait), which is fine for tests
    /// and needs no timer.
    async fn delay(duration: Duration) {
        let timer = Instant::now();
        poll_fn(|cx| {
            if timer.elapsed() >= duration {
                Poll::Ready(())
            } else {
                cx.waker().wake_by_ref();
                Poll::Pending
            }
        })
        .await;
    }

    #[test]
    fn test_promise_resolve() {
        let promise = Promise::<i32, &str>::resolve(42);
        assert_eq!(block_on(promise).unwrap(), 42);
    }

    #[test]
    fn test_promise_reject() {
        let promise = Promise::<i32, &str>::reject("error");
        assert_eq!(block_on(promise).unwrap_err(), "error");
    }

    #[test]
    fn test_promise_then() {
        let promise = Promise::<i32, &str>::resolve(2).then(|x| async move { Ok(x * 3) });
        assert_eq!(block_on(promise).unwrap(), 6);
    }

    #[test]
    fn test_promise_catch() {
        let promise = Promise::<i32, &str>::reject("error").catch(|_| async { Ok(99) });
        assert_eq!(block_on(promise).unwrap(), 99);
    }

    #[test]
    fn test_promise_finally() {
        let flag = Arc::new(Mutex::new(false));
        let flag_clone = Arc::clone(&flag);
        let promise = Promise::<i32, &str>::resolve(5).finally(move || async move {
            let mut flag = flag_clone.lock().unwrap();
            *flag = true;
        });
        assert_eq!(block_on(promise).unwrap(), 5);
        assert!(*flag.lock().unwrap());
    }

    #[test]
    fn test_promise_transform() {
        let promise = Promise::<i32, &str>::resolve(10)
            .transform(|result| async move { result.map(|v| v * 2) });
        assert_eq!(block_on(promise).unwrap(), 20);
    }

    #[test]
    fn test_promise_all() {
        let promise = Promise::all([
            Promise::<i32, &str>::resolve(1),
            Promise::<i32, &str>::resolve(2),
            Promise::<i32, &str>::resolve(3),
        ]);
        assert_eq!(block_on(promise).unwrap(), vec![1, 2, 3]);
    }

    #[test]
    fn test_promise_all_keeps_input_order() {
        let promise = Promise::all([
            Promise::<i32, &str>::new(async {
                delay(Duration::from_millis(20)).await;
                Ok(1)
            }),
            Promise::<i32, &str>::resolve(2),
        ]);
        assert_eq!(block_on(promise).unwrap(), vec![1, 2]);
    }

    #[test]
    fn test_promise_all_empty() {
        let promise = Promise::<i32, &str>::all(Vec::new());
        assert_eq!(block_on(promise).unwrap(), Vec::<i32>::new());
    }

    #[test]
    fn test_promise_all_rejects_without_waiting_for_pending() {
        let promise = Promise::all([
            Promise::<i32, &str>::new(pending()),
            Promise::<i32, &str>::reject("boom"),
        ]);
        assert_eq!(block_on(promise).unwrap_err(), "boom");
    }

    #[test]
    fn test_promise_any() {
        // The 0 ms promise settles first even though it is not first in the list.
        let promise = Promise::any([
            Promise::<i32, &str>::new(async {
                delay(Duration::from_millis(10)).await;
                Ok(1)
            }),
            Promise::<i32, &str>::new(async {
                delay(Duration::from_millis(0)).await;
                Ok(2)
            }),
            Promise::<i32, &str>::new(async {
                delay(Duration::from_millis(30)).await;
                Ok(3)
            }),
        ]);
        assert_eq!(block_on(promise).unwrap(), 2);
    }

    #[test]
    fn test_promise_any_is_not_blocked_by_pending_promises() {
        let promise = Promise::any([
            Promise::<i32, &str>::new(pending()),
            Promise::<i32, &str>::resolve(7),
        ]);
        assert_eq!(block_on(promise).unwrap(), 7);
    }

    #[test]
    fn test_promise_any_settles_with_first_rejection() {
        let promise = Promise::any([
            Promise::<i32, &str>::new(pending()),
            Promise::<i32, &str>::reject("boom"),
        ]);
        assert_eq!(block_on(promise).unwrap_err(), "boom");
    }

    #[test]
    fn test_promise_chain() {
        let promise = Promise::<i32, String>::new(async {
            delay(Duration::from_millis(10)).await;
            Ok(1)
        })
        .then(|x| async move { Ok(x + 3) })
        .transform(|result| async move { result.map(|v| v * 2) })
        .catch(|err| async move {
            println!("Caught an error: {}", err);
            Ok(0)
        })
        .finally(|| async { println!("Done") });
        assert_eq!(block_on(promise).unwrap(), 8);
    }
}