limited_join/
lib.rs

1//! A crate providing a future similar to
2//! [future's join](https://docs.rs/futures/latest/futures/future/fn.join.html) but with a limited
3//! amount of concurrency.
4//!
5//! # Example
6//!
7//! ```rust
8//! # tokio_test::block_on(async {
9//! # use limited_join::*;
10//! # mod download {
11//! #     pub async fn download_file(url: String) {}
12//! # }
13//! # let files_to_download: Vec<String> = vec![];
14//! // Pretend we have a ton of files we want to download, but don't want to
15//! // overwhelm the server.
16//! let futures = files_to_download.into_iter().map(download::download_file);
17//!
18//! // Let's limit the number of concurrent downloads to 4, and wait for all
19//! // the files to download.
20//! limited_join::join(futures, 4).await;
21//!
22//! # });
23//! ```
24use std::{
25    future::Future,
26    pin::Pin,
27    task::{Context, Poll},
28};
29
30/// The [`Future`] behind the [`join`] function.
31pub struct LimitedJoin<Fut>
32where
33    Fut: Future,
34{
35    inner: Pin<Box<[MaybeCompleted<Fut>]>>,
36    /// How many futures can be concurrently pending.
37    concurrency: usize,
38}
39
40/// Returns a future that acts as a [join](https://docs.rs/futures/latest/futures/future/fn.join.html)
41/// of multiple futures, but with a limit on how many futures can be running at once.
42///
43/// # Example
44/// ```rust
45/// # tokio_test::block_on(async {
46/// use std::time::{Duration, Instant};
47/// use tokio::time::sleep;
48///
49/// let then = Instant::now();
50/// let futures = std::iter::repeat(Duration::from_millis(100))
51///     .map(|duration| async move { sleep(duration).await })
52///     .take(4);
53///
54/// limited_join::join(futures, 2).await;
55///
56/// // Ensure all futures completed in roughly 200ms as we're processing only 2 at a time.
57/// assert!(then.elapsed().as_millis() - 200 < 10);
58/// # });
59/// ```
60pub fn join<Fut>(futures: impl IntoIterator<Item = Fut>, concurrency: usize) -> LimitedJoin<Fut>
61where
62    Fut: Future,
63{
64    let futures = futures
65        .into_iter()
66        .map(MaybeCompleted::InProgress)
67        .collect::<Vec<_>>()
68        .into_boxed_slice();
69    LimitedJoin {
70        inner: futures.into(),
71        concurrency,
72    }
73}
74
75impl<Fut> Future for LimitedJoin<Fut>
76where
77    Fut: Future,
78{
79    type Output = Vec<Fut::Output>;
80
81    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
82        // SAFETY: This is safe because we never move the inner futures.
83        let this = unsafe { Pin::get_unchecked_mut(self) };
84        // SAFETY: This is safe because we never move any of the futures.
85        let states = unsafe { Pin::get_unchecked_mut(this.inner.as_mut()) };
86
87        let mut remaining = states.iter().filter(|state| state.is_in_progress()).count();
88        let mut to_poll = this.concurrency.min(remaining);
89
90        let mut polled = 0;
91        let mut index = 0;
92
93        while polled < to_poll && index < states.len() {
94            let state = &mut states[index];
95
96            // Ensure that the future is ready to be polled, either by being new or by having been
97            // woken up.
98            if !state.is_in_progress() {
99                index += 1;
100                continue;
101            }
102
103            // SAFETY: This is all behind a Pin and can't be unpinned so we know it's in the same
104            // location in memory.
105            let res = unsafe { Pin::new_unchecked(state).poll(cx) };
106
107            if let Poll::Ready(output) = res {
108                states[index] = MaybeCompleted::Completed(output);
109                remaining -= 1;
110
111                // We've completed a future, so we can poll another one.
112                to_poll += 1;
113            }
114
115            polled += 1;
116            index += 1;
117        }
118
119        if remaining == 0 {
120            Poll::Ready(states.iter_mut().map(|state| state.take()).collect())
121        } else {
122            Poll::Pending
123        }
124    }
125}
126
/// The state of one future inside a [`LimitedJoin`].
///
/// Within this crate a value only moves forward: `InProgress` → `Completed` → `Drained`.
enum MaybeCompleted<Fut: Future> {
    /// The future has not produced its output yet. Once the state has been pinned and polled,
    /// the future must never be moved out, only dropped in place.
    InProgress(Fut),
    /// The future finished; its output is waiting to be collected.
    Completed(Fut::Output),
    /// The output has already been moved out by [`MaybeCompleted::take`].
    Drained,
}

impl<Fut: Future> MaybeCompleted<Fut> {
    /// Returns `true` while the wrapped future has not produced its output yet.
    fn is_in_progress(&self) -> bool {
        matches!(self, Self::InProgress(_))
    }

    /// Moves the completed output out, leaving `Drained` in its place.
    ///
    /// # Panics
    ///
    /// Panics if the future has not completed yet, or if its output was already taken.
    fn take(&mut self) -> Fut::Output {
        // Check the variant *before* replacing: an `InProgress` future may be pinned, so it must
        // never be moved out of `self`, not even on the way to a panic.
        match self {
            Self::Completed(_) => {}
            Self::InProgress(_) => panic!("attempt to get output of incomplete future"),
            Self::Drained => panic!("attempt to get output of drained future"),
        }
        match std::mem::replace(self, Self::Drained) {
            Self::Completed(output) => output,
            Self::InProgress(_) | Self::Drained => {
                unreachable!("variant was checked to be `Completed` above")
            }
        }
    }

    /// Polls the wrapped future.
    ///
    /// # Panics
    ///
    /// Panics if called on a `Completed` or `Drained` value.
    ///
    /// # Safety
    ///
    /// The inner future is treated as structurally pinned. While `self` holds an `InProgress`
    /// future, the caller must never move that future out of this (pinned) value, e.g. through
    /// `&mut` access with `mem::replace`/`mem::swap`. It may only be dropped in place.
    unsafe fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Fut::Output> {
        // SAFETY: `self` is pinned and, per this function's contract, the inner future is never
        // moved out while in progress, so projecting the pin onto it is sound.
        match unsafe { self.get_unchecked_mut() } {
            Self::InProgress(future) => unsafe { Pin::new_unchecked(future) }.poll(cx),
            Self::Completed(_) | Self::Drained => {
                unreachable!("attempted to poll a complete or drained future")
            }
        }
    }
}
155
#[cfg(test)]
mod tests {
    use std::{
        sync::{
            atomic::{AtomicBool, Ordering},
            Arc,
        },
        time::Duration,
    };

    use tokio::time::sleep;

    use super::*;

    /// With fewer futures than the limit, everything runs concurrently. Two 50ms sleeps finish in
    /// ~50ms, comfortably inside a timeout that a sequential run (~100ms) would exceed.
    #[tokio::test]
    async fn test_not_above_limit() {
        let joined = join(
            [
                sleep(Duration::from_millis(50)),
                sleep(Duration::from_millis(50)),
            ],
            10,
        );

        let timeout = tokio::time::timeout(Duration::from_millis(75), joined);
        timeout.await.expect("future timed out before completion");
    }

    /// With a limit of one, the second future must not start before the first has finished.
    #[tokio::test]
    async fn test_above_limit_no_concurrency() {
        let completed = Arc::new(AtomicBool::new(false));
        let run = |expected: bool| {
            let completed = completed.clone();
            async move {
                let loaded = completed.load(Ordering::SeqCst);
                assert_eq!(loaded, expected);
                sleep(Duration::from_millis(10)).await;
                completed.store(true, Ordering::SeqCst);
            }
        };

        join([run(false), run(true)], 1).await;
    }

    /// With more futures than the limit, a new future starts only as an earlier one finishes.
    #[tokio::test]
    async fn test_above_limit() {
        let (tx, rx) = std::sync::mpsc::channel();
        let record = |id: usize, millis: u64| {
            let tx = tx.clone();
            async move {
                tx.send(format!("s{id}")).unwrap();
                sleep(Duration::from_millis(millis)).await;
                tx.send(format!("e{id}")).unwrap();
            }
        };

        join(
            [record(0, 10), record(1, 25), record(2, 50), record(3, 50)],
            2,
        )
        .await;

        // Every clone of the sender was dropped along with its finished future; dropping the
        // original closes the channel, so collecting below sees exactly the recorded events.
        drop(tx);
        let events: Vec<String> = rx.into_iter().collect();

        assert_eq!(
            events,
            [
                // First two futures are polled concurrently.
                "s0", "s1",
                // Next the first future resolves, causing us to poll the third.
                "e0", "s2",
                // Our second future has now resolved, so we can start the last future.
                "e1", "s3",
                // Finally, we wait for the last futures to resolve.
                "e2", "e3",
            ]
        );
    }
}