Skip to main content

ps_promise/methods/
all.rs

1use std::{
2    future::Future,
3    task::Poll::{Pending, Ready},
4};
5
6use crate::{Promise, PromiseRejection};
7
8impl<T, E> Promise<T, E>
9where
10    T: Send + Unpin + 'static,
11    E: PromiseRejection,
12{
13    pub fn all<I>(promises: I) -> Promise<Vec<T>, E>
14    where
15        I: IntoIterator<Item = Self>,
16    {
17        Promise::new(PromiseAll::from(promises))
18    }
19}
20
21impl<T, E> Future for PromiseAll<T, E>
22where
23    T: Send + Unpin + 'static,
24    E: PromiseRejection,
25{
26    type Output = Result<Vec<T>, E>;
27
28    fn poll(
29        self: std::pin::Pin<&mut Self>,
30        cx: &mut std::task::Context<'_>,
31    ) -> std::task::Poll<Self::Output> {
32        let this = self.get_mut();
33
34        // Phase 1: drive all promises to completion.
35        let mut is_pending = false;
36        let mut rejection_idx = None;
37
38        for (idx, promise) in this.promises.iter_mut().enumerate() {
39            if promise.pending(cx) {
40                is_pending = true;
41            }
42
43            if rejection_idx.is_none() && promise.is_rejected() {
44                rejection_idx = Some(idx);
45            }
46        }
47
48        if is_pending {
49            return Pending;
50        }
51
52        if let Some(err) = rejection_idx
53            .and_then(|idx| this.promises.get_mut(idx))
54            .and_then(Promise::consume)
55            .and_then(Result::err)
56        {
57            return Ready(Err(err));
58        }
59
60        // Phase 2: collect values
61        let mut values = Vec::new();
62
63        for promise in &mut this.promises {
64            match promise.consume() {
65                Some(Ok(val)) => values.push(val),
66                Some(Err(err)) => return Ready(Err(err)),
67                None => unreachable!("All promises are settled."),
68            }
69        }
70
71        Ready(Ok(values))
72    }
73}
74
75pub struct PromiseAll<T, E>
76where
77    T: Send + Unpin + 'static,
78    E: PromiseRejection,
79{
80    promises: Vec<Promise<T, E>>,
81}
82
83impl<I, T, E> From<I> for PromiseAll<T, E>
84where
85    T: Send + Unpin + 'static,
86    E: PromiseRejection,
87    I: IntoIterator<Item = Promise<T, E>>,
88{
89    fn from(value: I) -> Self {
90        Self {
91            promises: value.into_iter().collect(),
92        }
93    }
94}
95
#[cfg(test)]
mod tests {
    use std::{
        future::{poll_fn, Future},
        task::{Context, Poll, Waker},
    };

    use crate::{Promise, PromiseRejection};

    #[derive(thiserror::Error, Debug, Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord)]
    enum E {
        #[error("Promise already consumed.")]
        AlreadyConsumed,
        #[error("Code: {0}")]
        Code(i32),
    }

    impl PromiseRejection for E {
        fn already_consumed() -> Self {
            Self::AlreadyConsumed
        }
    }

    /// Context backed by a no-op waker; the tests drive promises by calling `ready` manually.
    fn cx() -> Context<'static> {
        Context::from_waker(Waker::noop())
    }

    /// Future that returns `Pending` on its first poll (after waking itself) and
    /// `Ready(())` on every later poll.
    ///
    /// Lets a test keep one input of `Promise::all` unsettled for exactly one poll.
    fn yield_once() -> impl Future<Output = ()> + Send + 'static {
        let mut yielded = false;
        poll_fn(move |cx| {
            if yielded {
                Poll::Ready(())
            } else {
                yielded = true;
                cx.waker().wake_by_ref();
                Poll::Pending
            }
        })
    }

    #[test]
    fn empty() {
        let mut all: Promise<Vec<()>, E> = Promise::all([]);
        all.ready(&mut cx());

        match all {
            Promise::Resolved(v) => assert!(v.is_empty()),
            other => panic!("expected Resolved(vec![]), got {other:?}"),
        }
    }

    #[test]
    fn all_resolve() {
        let mut all = Promise::all([
            Promise::new(async { Ok::<_, E>(1) }),
            Promise::new(async { Ok(2) }),
            Promise::new(async { Ok(3) }),
        ]);

        all.ready(&mut cx());

        match all {
            Promise::Resolved(v) => assert_eq!(v, vec![1, 2, 3]),
            other => panic!("expected Resolved([1,2,3]), got {other:?}"),
        }
    }

    #[test]
    fn single_rejection() {
        let mut all: Promise<Vec<i32>, E> = Promise::all([
            Promise::new(async { Ok(1) }),
            Promise::new(async { Err(E::Code(2)) }),
            Promise::new(async { Ok(3) }),
        ]);

        all.ready(&mut cx());

        match all {
            Promise::Rejected(E::Code(2)) => {}
            other => panic!("expected Rejected(Code(2)), got {other:?}"),
        }
    }

    #[test]
    fn returns_first_error() {
        let mut all: Promise<Vec<i32>, E> = Promise::all([
            Promise::new(async { Err(E::Code(1)) }),
            Promise::new(async { Ok(99) }),
            Promise::new(async { Err(E::Code(2)) }),
            Promise::new(async { Err(E::Code(3)) }),
        ]);

        all.ready(&mut cx());
        assert_eq!(all.consume(), Some(Err(E::Code(1))));

        all.ready(&mut cx());
        assert_eq!(all.consume(), Some(Err(E::AlreadyConsumed)));
    }

    #[test]
    fn repoll_after_success_yields_already_consumed() {
        let mut all: Promise<Vec<i32>, E> = Promise::all([Promise::new(async { Ok(1) })]);

        all.ready(&mut cx());
        assert_eq!(all.consume(), Some(Ok(vec![1])));

        all.ready(&mut cx());
        assert_eq!(all.consume(), Some(Err(E::AlreadyConsumed)));
    }

    #[test]
    fn all_rejected() {
        let mut all: Promise<Vec<i32>, E> = Promise::all([
            Promise::new(async { Err(E::Code(10)) }),
            Promise::new(async { Err(E::Code(20)) }),
        ]);

        all.ready(&mut cx());
        assert_eq!(all.consume(), Some(Err(E::Code(10))));

        all.ready(&mut cx());
        assert_eq!(all.consume(), Some(Err(E::AlreadyConsumed)));
    }

    /// The combined promise must not settle while any input is still pending.
    #[test]
    fn stays_pending_until_every_input_settles() {
        let mut all = Promise::all([
            Promise::new(async {
                yield_once().await;
                Ok::<_, E>(1)
            }),
            Promise::new(async { Ok(2) }),
        ]);

        all.ready(&mut cx());
        assert!(
            !matches!(all, Promise::Resolved(_) | Promise::Rejected(_)),
            "expected still pending after first poll, got {all:?}"
        );

        all.ready(&mut cx());
        assert_eq!(all.consume(), Some(Ok(vec![1, 2])));
    }

    /// A rejection is only reported once its pending siblings have settled.
    #[test]
    fn rejection_waits_for_pending_siblings() {
        let mut all: Promise<Vec<i32>, E> = Promise::all([
            Promise::new(async {
                yield_once().await;
                Ok(1)
            }),
            Promise::new(async { Err(E::Code(7)) }),
        ]);

        all.ready(&mut cx());
        assert!(
            !matches!(all, Promise::Resolved(_) | Promise::Rejected(_)),
            "expected still pending after first poll, got {all:?}"
        );

        all.ready(&mut cx());
        assert_eq!(all.consume(), Some(Err(E::Code(7))));
    }
}