1use std::{
25 future::Future,
26 pin::Pin,
27 task::{Context, Poll},
28};
29
30pub struct LimitedJoin<Fut>
32where
33 Fut: Future,
34{
35 inner: Pin<Box<[MaybeCompleted<Fut>]>>,
36 concurrency: usize,
38}
39
40pub fn join<Fut>(futures: impl IntoIterator<Item = Fut>, concurrency: usize) -> LimitedJoin<Fut>
61where
62 Fut: Future,
63{
64 let futures = futures
65 .into_iter()
66 .map(MaybeCompleted::InProgress)
67 .collect::<Vec<_>>()
68 .into_boxed_slice();
69 LimitedJoin {
70 inner: futures.into(),
71 concurrency,
72 }
73}
74
75impl<Fut> Future for LimitedJoin<Fut>
76where
77 Fut: Future,
78{
79 type Output = Vec<Fut::Output>;
80
81 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
82 let this = unsafe { Pin::get_unchecked_mut(self) };
84 let states = unsafe { Pin::get_unchecked_mut(this.inner.as_mut()) };
86
87 let mut remaining = states.iter().filter(|state| state.is_in_progress()).count();
88 let mut to_poll = this.concurrency.min(remaining);
89
90 let mut polled = 0;
91 let mut index = 0;
92
93 while polled < to_poll && index < states.len() {
94 let state = &mut states[index];
95
96 if !state.is_in_progress() {
99 index += 1;
100 continue;
101 }
102
103 let res = unsafe { Pin::new_unchecked(state).poll(cx) };
106
107 if let Poll::Ready(output) = res {
108 states[index] = MaybeCompleted::Completed(output);
109 remaining -= 1;
110
111 to_poll += 1;
113 }
114
115 polled += 1;
116 index += 1;
117 }
118
119 if remaining == 0 {
120 Poll::Ready(states.iter_mut().map(|state| state.take()).collect())
121 } else {
122 Poll::Pending
123 }
124 }
125}
126
/// State of a single slot in a `LimitedJoin`.
enum MaybeCompleted<Fut>
where
    Fut: Future,
{
    /// The future has not produced its output yet and may still need polling.
    InProgress(Fut),
    /// The future finished; its output is parked here until every slot is done.
    Completed(Fut::Output),
    /// The output has already been moved out of this slot.
    Drained,
}
132
133impl<Fut: Future> MaybeCompleted<Fut> {
134 fn is_in_progress(&self) -> bool {
135 matches!(self, Self::InProgress { .. })
136 }
137
138 fn take(&mut self) -> Fut::Output {
139 match std::mem::replace(self, MaybeCompleted::Drained) {
140 Self::Completed(output) => output,
141 Self::InProgress(_) => panic!("attempt to get output of incomplete future"),
142 Self::Drained => panic!("attempt to get output of drained future"),
143 }
144 }
145
146 unsafe fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Fut::Output> {
147 let this = self.as_mut();
148 let this = this.get_unchecked_mut();
149 match this {
150 Self::InProgress(future) => Pin::new_unchecked(future).poll(cx),
151 _ => unreachable!("attempted to poll a complete or drained future"),
152 }
153 }
154}
155
#[cfg(test)]
mod tests {
    use std::{
        sync::{
            atomic::{AtomicBool, AtomicUsize, Ordering},
            Arc,
        },
        time::Duration,
    };

    use tokio::time::sleep;

    use super::*;

    #[tokio::test]
    async fn test_not_above_limit() {
        // Each future waits until both have started, so the join can only finish if
        // the two are in flight at the same time. Fixed-duration sleeps cannot check
        // this: `sleep` sets its deadline when it is created, so they would finish on
        // time even if polled one after the other.
        let started = Arc::new(AtomicUsize::new(0));
        let rendezvous = || {
            let started = started.clone();
            async move {
                started.fetch_add(1, Ordering::SeqCst);
                while started.load(Ordering::SeqCst) < 2 {
                    sleep(Duration::from_millis(1)).await;
                }
            }
        };

        let joined = join([rendezvous(), rendezvous()], 10);
        let timeout = tokio::time::timeout(Duration::from_secs(1), joined);
        timeout.await.expect("future timed out before completion");
    }

    #[tokio::test]
    async fn test_above_limit_no_concurrency() {
        let completed = Arc::new(AtomicBool::new(false));
        let run = |expected: bool| {
            let completed = completed.clone();
            async move {
                let loaded = completed.load(Ordering::SeqCst);
                assert_eq!(loaded, expected);
                sleep(Duration::from_millis(10)).await;
                completed.store(true, Ordering::SeqCst);
            }
        };

        join([run(false), run(true)], 1).await;
    }

    #[tokio::test]
    async fn test_above_limit() {
        let (tx, rx) = std::sync::mpsc::channel();
        let record = |id: usize, millis: u64| {
            let tx = tx.clone();
            async move {
                tx.send(format!("s{id}")).unwrap();
                sleep(Duration::from_millis(millis)).await;
                tx.send(format!("e{id}")).unwrap();
            }
        };

        join(
            [record(0, 10), record(1, 25), record(2, 50), record(3, 50)],
            2,
        )
        .await;

        // Close the channel so collecting stops after the last event instead of
        // blocking forever if an event is missing.
        drop(tx);
        let order: Vec<String> = rx.iter().collect();

        // With a limit of 2, future 2 may only start once future 0 has ended, and
        // future 3 only once future 1 has ended.
        assert_eq!(order, ["s0", "s1", "e0", "s2", "e1", "s3", "e2", "e3"]);
    }

    #[tokio::test]
    async fn test_outputs_in_input_order() {
        let delayed = |value: u32, millis: u64| async move {
            sleep(Duration::from_millis(millis)).await;
            value
        };

        // The futures finish out of input order; the result must still follow it.
        let outputs = join([delayed(1, 20), delayed(2, 1), delayed(3, 5)], 2).await;
        assert_eq!(outputs, vec![1, 2, 3]);
    }

    #[tokio::test]
    async fn test_empty_input() {
        let outputs = join(Vec::<std::future::Ready<u8>>::new(), 3).await;
        assert!(outputs.is_empty());
    }
}