futures_concurrency/future/try_join/
array.rs1use super::TryJoin as TryJoinTrait;
2use crate::utils::{FutureArray, OutputArray, PollArray, WakerArray};
3
4use core::fmt;
5use core::future::{Future, IntoFuture};
6use core::mem::ManuallyDrop;
7use core::ops::DerefMut;
8use core::pin::Pin;
9use core::task::{Context, Poll};
10
11use pin_project::{pin_project, pinned_drop};
12
/// A future which waits for every future in an array to complete successfully,
/// or resolves early with the first error observed.
///
/// Built by calling `try_join` on an array of futures (the `TryJoinTrait`
/// implementation for `[Fut; N]` in this module). On success it yields the
/// outputs as an array in the same order as the input futures.
#[must_use = "futures do nothing unless you `.await` or poll them"]
#[pin_project(PinnedDrop)]
pub struct TryJoin<Fut, T, E, const N: usize>
where
    Fut: Future<Output = Result<T, E>>,
{
    /// Set once this future has returned `Poll::Ready` (with `Ok` or `Err`);
    /// polling again after that panics.
    consumed: bool,
    /// The number of child futures that have not yet completed.
    pending: usize,
    /// Successful outputs, each written at the index of the future that
    /// produced it and taken as a whole once every future has succeeded.
    items: OutputArray<T, N>,
    /// Holds the parent task's waker plus one waker per child, and the
    /// readiness flags that decide which children get polled.
    wakers: WakerArray<N>,
    /// Per-child poll state (pending / ready / none), used to know which
    /// slots of `items` and `futures` are still live.
    state: PollArray<N>,
    /// The child futures. Pinned in place; each one is dropped in place as
    /// soon as it completes, and any still pending are dropped with `self`.
    #[pin]
    futures: FutureArray<Fut, N>,
}
41
42impl<Fut, T, E, const N: usize> TryJoin<Fut, T, E, N>
43where
44 Fut: Future<Output = Result<T, E>>,
45{
46 #[inline]
47 pub(crate) fn new(futures: [Fut; N]) -> Self {
48 Self {
49 consumed: false,
50 pending: N,
51 items: OutputArray::uninit(),
52 wakers: WakerArray::new(),
53 state: PollArray::new_pending(),
54 futures: FutureArray::new(futures),
55 }
56 }
57}
58
59impl<Fut, T, E, const N: usize> TryJoinTrait for [Fut; N]
60where
61 Fut: IntoFuture<Output = Result<T, E>>,
62{
63 type Output = [T; N];
64 type Error = E;
65 type Future = TryJoin<Fut::IntoFuture, T, E, N>;
66
67 fn try_join(self) -> Self::Future {
68 TryJoin::new(self.map(IntoFuture::into_future))
69 }
70}
71
72impl<Fut, T, E, const N: usize> fmt::Debug for TryJoin<Fut, T, E, N>
73where
74 Fut: Future<Output = Result<T, E>> + fmt::Debug,
75{
76 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
77 f.debug_list().entries(self.state.iter()).finish()
78 }
79}
80
impl<Fut, T, E, const N: usize> Future for TryJoin<Fut, T, E, N>
where
    Fut: Future<Output = Result<T, E>>,
{
    /// `Ok` with every output in input order, or the first `Err` observed.
    type Output = Result<[T; N], E>;

    /// Polls each still-pending child whose readiness flag is set.
    ///
    /// Returns `Ok` once every child has succeeded. On the first `Err` it
    /// returns immediately; children still pending are not polled again and
    /// are dropped when this future is dropped.
    ///
    /// # Panics
    ///
    /// Panics if polled again after returning `Poll::Ready`.
    #[inline]
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();

        assert!(
            !*this.consumed,
            "Futures must not be polled after completing"
        );

        // Record the parent task's waker, then bail out early when no child
        // has been flagged ready (unless every child is already done).
        let mut readiness = this.wakers.readiness();
        readiness.set_waker(cx.waker());
        if *this.pending != 0 && !readiness.any_ready() {
            return Poll::Pending;
        }

        for (i, mut fut) in this.futures.iter().enumerate() {
            // Only poll children that are still pending and whose readiness
            // flag was set (`clear_ready` presumably also resets the flag).
            if this.state[i].is_pending() && readiness.clear_ready(i) {
                // Release the readiness handle before polling the child, so a
                // child that wakes during `poll` can reach it.
                // NOTE(review): presumably a lock guard in some builds (hence
                // the clippy allow) — confirm in `WakerArray`.
                #[allow(clippy::drop_non_drop)]
                drop(readiness);

                // Poll the child with its own per-index waker, not the
                // parent's.
                let mut cx = Context::from_waker(this.wakers.get(i).unwrap());

                // SAFETY: the futures live in the pinned `FutureArray` and are
                // never moved; `state[i].is_pending()` guarantees slot `i`
                // has not been dropped yet.
                if let Poll::Ready(value) = unsafe {
                    fut.as_mut()
                        .map_unchecked_mut(|t| t.deref_mut())
                        .poll(&mut cx)
                } {
                    *this.pending -= 1;

                    match value {
                        Ok(value) => {
                            this.items.write(i, value);

                            // `ready` tells `take`/`PinnedDrop` that
                            // `items[i]` is initialized and `futures[i]` is
                            // gone.
                            this.state[i].set_ready();
                            // SAFETY: the state is no longer pending, so this
                            // future is never polled or dropped again; it is
                            // dropped in place, not moved.
                            unsafe { ManuallyDrop::drop(fut.get_unchecked_mut()) };
                        }
                        Err(err) => {
                            // Short-circuit: any further poll is a bug.
                            *this.consumed = true;

                            // `none`: slot `i` holds neither an output nor a
                            // live future, so `PinnedDrop` leaves it alone.
                            this.state[i].set_none();
                            // SAFETY: the state was updated first, so this
                            // future is never polled or dropped again; it is
                            // dropped in place, not moved.
                            unsafe { ManuallyDrop::drop(fut.get_unchecked_mut()) };

                            return Poll::Ready(Err(err));
                        }
                    }
                }

                // Re-acquire the readiness handle for the next iteration.
                readiness = this.wakers.readiness();
            }
        }

        if *this.pending == 0 {
            // Mark consumed before moving the outputs out, so a later poll
            // panics instead of reading moved-out data.
            *this.consumed = true;

            // Every slot must hold an output here. Resetting all states to
            // `none` stops `PinnedDrop` from dropping outputs that `take`
            // moves out.
            debug_assert!(this.state.iter().all(|entry| entry.is_ready()));
            this.state.set_all_none();
            // SAFETY: `pending == 0` and no `Err` was returned, so every slot
            // of `items` was written exactly once above.
            Poll::Ready(Ok(unsafe { this.items.take() }))
        } else {
            Poll::Pending
        }
    }
}
171
#[pinned_drop]
impl<Fut, T, E, const N: usize> PinnedDrop for TryJoin<Fut, T, E, N>
where
    Fut: Future<Output = Result<T, E>>,
{
    /// Drops whatever the join still owns: outputs of children that already
    /// succeeded, then children that never completed. Slots whose state was
    /// reset to `none` (an errored child, or all slots after the outputs were
    /// taken) own nothing and are skipped.
    fn drop(self: Pin<&mut Self>) {
        let mut this = self.project();

        for i in this.state.ready_indexes() {
            // SAFETY: a `ready` state means `items[i]` was written in `poll`
            // and has not been taken.
            unsafe { this.items.drop(i) };
        }

        for i in this.state.pending_indexes() {
            // SAFETY: a `pending` state means `futures[i]` was never dropped
            // in `poll`; it is dropped in place inside the pinned array.
            unsafe { this.futures.as_mut().drop(i) };
        }
    }
}
196
#[cfg(test)]
mod test {
    use super::*;
    use core::cell::Cell;
    use core::future;

    /// Test future that resolves on its first poll with a preset output, or
    /// stays pending forever when built without one. It counts how many times
    /// it is dropped, so tests can check that `TryJoin` drops every child
    /// exactly once.
    struct Scripted<'a> {
        output: Option<Result<u8, &'static str>>,
        drops: &'a Cell<usize>,
    }

    impl<'a> Scripted<'a> {
        fn ready(drops: &'a Cell<usize>, output: Result<u8, &'static str>) -> Self {
            Self {
                output: Some(output),
                drops,
            }
        }

        fn pending(drops: &'a Cell<usize>) -> Self {
            Self {
                output: None,
                drops,
            }
        }
    }

    impl Future for Scripted<'_> {
        type Output = Result<u8, &'static str>;

        fn poll(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
            match self.output.take() {
                Some(output) => Poll::Ready(output),
                None => Poll::Pending,
            }
        }
    }

    impl Drop for Scripted<'_> {
        fn drop(&mut self) {
            self.drops.set(self.drops.get() + 1);
        }
    }

    #[test]
    fn all_ok() {
        futures_lite::future::block_on(async {
            let res: Result<_, ()> = [future::ready(Ok("hello")), future::ready(Ok("world"))]
                .try_join()
                .await;
            assert_eq!(res.unwrap(), ["hello", "world"]);
        })
    }

    #[test]
    fn empty() {
        futures_lite::future::block_on(async {
            let data: [future::Ready<Result<(), ()>>; 0] = [];
            let res = data.try_join().await;
            assert_eq!(res.unwrap(), []);
        });
    }

    #[test]
    fn one_err() {
        futures_lite::future::block_on(async {
            let res: Result<_, _> = [future::ready(Ok("hello")), future::ready(Err("oh no"))]
                .try_join()
                .await;
            assert_eq!(res.unwrap_err(), "oh no");
        });
    }

    #[test]
    fn all_ok_drops_each_future_once() {
        let drops = Cell::new(0);
        let res = futures_lite::future::block_on(
            [Scripted::ready(&drops, Ok(1)), Scripted::ready(&drops, Ok(2))].try_join(),
        );
        assert_eq!(res, Ok([1, 2]));
        // Both children are dropped inside `poll` right after completing.
        assert_eq!(drops.get(), 2);
    }

    #[test]
    fn err_short_circuits_and_drops_pending() {
        let drops = Cell::new(0);
        futures_lite::future::block_on(async {
            let res = [Scripted::pending(&drops), Scripted::ready(&drops, Err("oh no"))]
                .try_join()
                .await;
            // The error wins even though the other child never completes.
            assert_eq!(res.unwrap_err(), "oh no");
        });
        // The failed child is dropped inside `poll`; the still-pending one is
        // dropped when the `TryJoin` itself is dropped — once each.
        assert_eq!(drops.get(), 2);
    }
}