Skip to main content

ps_promise/methods/any.rs

1use std::{
2    future::Future,
3    task::Poll::{Pending, Ready},
4};
5
6use crate::{Promise, PromiseRejection};
7
8impl<T, E> Promise<T, E>
9where
10    T: Send + Unpin + 'static,
11    E: PromiseRejection,
12{
13    pub fn any<I>(promises: I) -> Promise<T, Vec<E>>
14    where
15        I: IntoIterator<Item = Self>,
16    {
17        Promise::new(PromiseAny::from(promises))
18    }
19}
20
21impl<T, E> Future for PromiseAny<T, E>
22where
23    T: Send + Unpin + 'static,
24    E: PromiseRejection,
25{
26    type Output = Result<T, Vec<E>>;
27
28    fn poll(
29        self: std::pin::Pin<&mut Self>,
30        cx: &mut std::task::Context<'_>,
31    ) -> std::task::Poll<Self::Output> {
32        let this = self.get_mut();
33
34        let mut is_pending = false;
35
36        this.promises.iter_mut().for_each(|promise| {
37            if promise.pending(cx) {
38                is_pending = true;
39            }
40        });
41
42        if is_pending {
43            return Pending;
44        }
45
46        let mut errors = Vec::new();
47
48        for promise in &mut this.promises {
49            match promise.consume() {
50                Some(Ok(val)) => return Ready(Ok(val)),
51                Some(Err(err)) => errors.push(err),
52                None => unreachable!("We checked no Promise is pending."),
53            }
54        }
55
56        Ready(Err(errors))
57    }
58}
59
60pub struct PromiseAny<T, E>
61where
62    T: Send + Unpin + 'static,
63    E: PromiseRejection,
64{
65    promises: Vec<Promise<T, E>>,
66}
67
68impl<I, T, E> From<I> for PromiseAny<T, E>
69where
70    T: Send + Unpin + 'static,
71    E: PromiseRejection,
72    I: IntoIterator<Item = Promise<T, E>>,
73{
74    fn from(value: I) -> Self {
75        Self {
76            promises: value.into_iter().collect(),
77        }
78    }
79}
80
81#[cfg(test)]
82mod tests {
83    use std::task::{Context, Waker};
84
85    use crate::{Promise, PromiseRejection};
86
87    #[derive(thiserror::Error, Debug, Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord)]
88    enum E {
89        #[error("Promise already consumed.")]
90        AlreadyConsumed,
91        #[error("Code: {0}")]
92        Code(i32),
93    }
94
95    impl PromiseRejection for E {
96        fn already_consumed() -> Self {
97            Self::AlreadyConsumed
98        }
99    }
100
101    fn cx() -> Context<'static> {
102        Context::from_waker(Waker::noop())
103    }
104
105    #[test]
106    fn empty() {
107        let mut all: Promise<(), Vec<E>> = Promise::any([]);
108
109        all.ready(&mut cx());
110
111        if let Promise::Rejected(v) = all {
112            assert!(v.is_empty(), "Result vector is not empty!");
113        } else {
114            panic!("Invalid state for empty input: {all:?}");
115        }
116    }
117
118    #[test]
119    fn resolving() {
120        let mut all = Promise::any([
121            Promise::new(async { Err(E::Code(1)) }),
122            Promise::new(async { Ok(2) }),
123            Promise::new(async { Err(E::Code(3)) }),
124        ]);
125
126        all.ready(&mut cx());
127
128        if let Promise::Resolved(v) = all {
129            assert_eq!(v, 2);
130        } else {
131            panic!("Expected Resolved(2), got {all:?}");
132        }
133    }
134
135    #[test]
136    fn rejecting() {
137        let mut all: Promise<(), Vec<E>> = Promise::any([
138            Promise::new(async { Err(E::Code(1)) }),
139            Promise::new(async { Err(E::Code(2)) }),
140            Promise::new(async { Err(E::Code(3)) }),
141        ]);
142
143        all.ready(&mut cx());
144
145        if let Promise::Rejected(v) = all {
146            assert_eq!(v, [1, 2, 3].map(E::Code));
147        } else {
148            panic!("Expected Rejected([1,2,3]), got {all:?}");
149        }
150    }
151}