futures_concurrency/future/try_join/
vec.rs1use super::TryJoin as TryJoinTrait;
2use crate::utils::{FutureVec, OutputVec, PollVec, WakerVec};
3
4#[cfg(all(feature = "alloc", not(feature = "std")))]
5use alloc::vec::Vec;
6
7use core::fmt;
8use core::future::{Future, IntoFuture};
9use core::mem::ManuallyDrop;
10use core::ops::DerefMut;
11use core::pin::Pin;
12use core::task::{Context, Poll};
13
14use pin_project::{pin_project, pinned_drop};
15
16#[must_use = "futures do nothing unless you `.await` or poll them"]
24#[pin_project(PinnedDrop)]
25pub struct TryJoin<Fut, T, E>
26where
27 Fut: Future<Output = Result<T, E>>,
28{
29 consumed: bool,
31 pending: usize,
33 items: OutputVec<T>,
35 wakers: WakerVec,
38 state: PollVec,
40 #[pin]
41 futures: FutureVec<Fut>,
43}
44
45impl<Fut, T, E> TryJoin<Fut, T, E>
46where
47 Fut: Future<Output = Result<T, E>>,
48{
49 #[inline]
50 pub(crate) fn new(futures: Vec<Fut>) -> Self {
51 let len = futures.len();
52 Self {
53 consumed: false,
54 pending: len,
55 items: OutputVec::uninit(len),
56 wakers: WakerVec::new(len),
57 state: PollVec::new_pending(len),
58 futures: FutureVec::new(futures),
59 }
60 }
61}
62
63impl<Fut, T, E> TryJoinTrait for Vec<Fut>
64where
65 Fut: IntoFuture<Output = Result<T, E>>,
66{
67 type Output = Vec<T>;
68 type Error = E;
69 type Future = TryJoin<Fut::IntoFuture, T, E>;
70
71 fn try_join(self) -> Self::Future {
72 TryJoin::new(self.into_iter().map(IntoFuture::into_future).collect())
73 }
74}
75
76impl<Fut, T, E> fmt::Debug for TryJoin<Fut, T, E>
77where
78 Fut: Future<Output = Result<T, E>> + fmt::Debug,
79{
80 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
81 f.debug_list().entries(self.state.iter()).finish()
82 }
83}
84
85impl<Fut, T, E> Future for TryJoin<Fut, T, E>
86where
87 Fut: Future<Output = Result<T, E>>,
88{
89 type Output = Result<Vec<T>, E>;
90
91 #[inline]
92 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
93 let this = self.project();
94
95 assert!(
96 !*this.consumed,
97 "Futures must not be polled after completing"
98 );
99
100 let mut readiness = this.wakers.readiness();
101 readiness.set_waker(cx.waker());
102 if *this.pending != 0 && !readiness.any_ready() {
103 return Poll::Pending;
105 }
106
107 for (i, mut fut) in this.futures.iter().enumerate() {
109 if this.state[i].is_pending() && readiness.clear_ready(i) {
110 #[allow(clippy::drop_non_drop)]
112 drop(readiness);
113
114 let mut cx = Context::from_waker(this.wakers.get(i).unwrap());
116
117 if let Poll::Ready(value) = unsafe {
120 fut.as_mut()
121 .map_unchecked_mut(|t| t.deref_mut())
122 .poll(&mut cx)
123 } {
124 *this.pending -= 1;
125
126 match value {
128 Ok(value) => {
129 this.items.write(i, value);
130
131 this.state[i].set_ready();
136 unsafe { ManuallyDrop::drop(fut.get_unchecked_mut()) };
137 }
138 Err(err) => {
139 *this.consumed = true;
141
142 this.state[i].set_none();
148 unsafe { ManuallyDrop::drop(fut.get_unchecked_mut()) };
149
150 return Poll::Ready(Err(err));
151 }
152 }
153 }
154
155 readiness = this.wakers.readiness();
157 }
158 }
159
160 if *this.pending == 0 {
162 *this.consumed = true;
164 for state in this.state.iter_mut() {
165 debug_assert!(
166 state.is_ready(),
167 "Future should have reached a `Ready` state"
168 );
169 state.set_none();
170 }
171
172 Poll::Ready(Ok(unsafe { this.items.take() }))
175 } else {
176 Poll::Pending
177 }
178 }
179}
180
181#[pinned_drop]
183impl<Fut, T, E> PinnedDrop for TryJoin<Fut, T, E>
184where
185 Fut: Future<Output = Result<T, E>>,
186{
187 fn drop(self: Pin<&mut Self>) {
188 let mut this = self.project();
189
190 for i in this.state.ready_indexes() {
192 unsafe { this.items.drop(i) };
195 }
196
197 for i in this.state.pending_indexes() {
199 unsafe { this.futures.as_mut().drop(i) };
202 }
203 }
204}
205
#[cfg(test)]
mod test {
    use super::*;
    use alloc::vec;
    use core::future;

    /// All futures succeed: outputs come back in input order.
    #[test]
    fn all_ok() {
        futures_lite::future::block_on(async {
            let inputs = vec![future::ready(Ok("hello")), future::ready(Ok("world"))];
            let res: Result<_, ()> = inputs.try_join().await;
            assert_eq!(res.unwrap(), ["hello", "world"]);
        })
    }

    /// An empty vec resolves immediately to an empty `Ok` vec.
    #[test]
    fn empty() {
        futures_lite::future::block_on(async {
            let data: Vec<future::Ready<Result<(), ()>>> = vec![];
            let res = data.try_join().await;
            assert_eq!(res.unwrap(), vec![]);
        });
    }

    /// A single failing future makes the whole join resolve to its error.
    #[test]
    fn one_err() {
        futures_lite::future::block_on(async {
            let inputs = vec![future::ready(Ok("hello")), future::ready(Err("oh no"))];
            let res: Result<_, _> = inputs.try_join().await;
            assert_eq!(res.unwrap_err(), "oh no");
        });
    }
}