1use std::borrow::BorrowMut;
4use std::pin::Pin;
5use std::task::{Context, Poll};
6
7use futures::stream::{FuturesUnordered, Stream};
8use futures::Future;
9use pin_project::pin_project;
10
11use crate::common::Passthrough;
12
13#[pin_project]
23pub struct FuturesUnorderedIter<T, F>
24where
25 F: Future,
26 T: Iterator<Item = F>,
27{
28 max_concurrent: usize,
29 tasks: T,
30 #[pin]
31 running_tasks: FuturesUnordered<F>,
32}
33
34#[pin_project]
43pub struct FuturesUnorderedBounded<F>
44where
45 F: Future,
46{
47 max_concurrent: usize,
48 queued_tasks: Vec<F>,
49 #[pin]
50 running_tasks: FuturesUnordered<F>,
51}
52
53impl<T, F> Passthrough<F> for FuturesUnorderedIter<T, F>
54where
55 T: Iterator<Item = F>,
56 F: Future,
57{
58 type FuturesHolder = FuturesUnordered<F>;
59
60 fn set_max_concurrent(&mut self, max_concurrent: usize) {
61 self.max_concurrent = max_concurrent;
62 }
63
64 fn borrow_inner(&self) -> &FuturesUnordered<F> {
65 &self.running_tasks
66 }
67
68 fn borrow_mut_inner(&mut self) -> &mut FuturesUnordered<F> {
69 self.running_tasks.borrow_mut()
70 }
71
72 fn into_inner(self) -> FuturesUnordered<F>
73 where
74 Self: Sized,
75 {
76 self.running_tasks
77 }
78}
79
80impl<F> Passthrough<F> for FuturesUnorderedBounded<F>
81where
82 F: Future,
83{
84 type FuturesHolder = FuturesUnordered<F>;
85
86 fn set_max_concurrent(&mut self, max_concurrent: usize) {
87 self.max_concurrent = max_concurrent;
88 }
89
90 fn borrow_inner(&self) -> &FuturesUnordered<F> {
91 &self.running_tasks
92 }
93
94 fn borrow_mut_inner(&mut self) -> &mut FuturesUnordered<F> {
95 self.running_tasks.borrow_mut()
96 }
97
98 fn into_inner(self) -> FuturesUnordered<F>
99 where
100 Self: Sized,
101 {
102 self.running_tasks
103 }
104}
105
106impl<T, F> Stream for FuturesUnorderedIter<T, F>
107where
108 T: Iterator<Item = F>,
109 F: Future,
110{
111 type Item = F::Output;
112
113 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
115 let mut this = self.project();
116 match this.running_tasks.as_mut().poll_next(cx) {
117 Poll::Ready(Some(value)) => {
118 while this.running_tasks.len() < *this.max_concurrent {
119 match this.tasks.next() {
120 Some(future) => this.running_tasks.push(future),
121 _ => break,
122 }
123 }
124 Poll::Ready(Some(value))
125 }
126 Poll::Ready(None) => Poll::Ready(None),
127 Poll::Pending => Poll::Pending,
128 }
129 }
130}
131
132impl<F> futures::stream::Stream for FuturesUnorderedBounded<F>
133where
134 F: Future,
135{
136 type Item = F::Output;
137
138 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
140 let mut this = self.project();
141 match this.running_tasks.as_mut().poll_next(cx) {
142 Poll::Ready(Some(value)) => {
143 while this.running_tasks.len() < *this.max_concurrent {
144 match this.queued_tasks.pop() {
145 Some(future) => this.running_tasks.push(future),
146 _ => break,
147 }
148 }
149 Poll::Ready(Some(value))
150 }
151 Poll::Ready(None) => Poll::Ready(None),
152 Poll::Pending => Poll::Pending,
153 }
154 }
155}
156
157impl<T, F> FuturesUnorderedIter<T, F>
158where
159 F: Future,
160 T: Iterator<Item = F>,
161{
162 pub fn new<I: IntoIterator<IntoIter = T>>(max_concurrent: usize, tasks: I) -> Self {
200 let running_tasks = FuturesUnordered::new();
201 assert!(max_concurrent > 0, "max_concurrent must be greater than 0");
202 let mut tasks = tasks.into_iter();
203 tasks
205 .borrow_mut()
206 .take(max_concurrent)
207 .for_each(|future| running_tasks.push(future));
208
209 Self {
210 max_concurrent,
211 tasks,
212 running_tasks,
213 }
214 }
215}
216
217impl<F> FuturesUnorderedBounded<F>
218where
219 F: Future,
220{
221 pub fn new(max_concurrent: usize) -> Self {
257 let running_tasks = FuturesUnordered::new();
258 assert!(max_concurrent > 0, "max_concurrent must be greater than 0");
259 Self {
260 max_concurrent,
261 queued_tasks: Vec::new(),
262 running_tasks,
263 }
264 }
265
266 pub fn push(&mut self, fut: F) {
272 match self.borrow_inner().len() < self.max_concurrent {
273 true => self.borrow_mut_inner().push(fut),
274 false => self.queued_tasks.push(fut),
275 }
276 }
277
278 pub fn borrow_mut_queue(&mut self) -> &mut Vec<F> {
280 &mut self.queued_tasks
281 }
282
283 pub fn borrow_queue(&self) -> &Vec<F> {
285 &self.queued_tasks
286 }
287
288 pub fn clear_queue(&mut self) {
292 self.queued_tasks.clear();
293 self.borrow_mut_inner().clear();
294 }
295
296 pub fn len_queue(&self) -> usize {
298 self.queued_tasks.len()
299 }
300}
301
#[cfg(test)]
mod tests {
    use std::{
        cmp,
        sync::{
            atomic::{AtomicU8, Ordering},
            Arc,
        },
        time::Duration,
    };

    use futures::StreamExt;

    use super::*;

    /// Builds 100 futures that resolve to their index on the first poll.
    fn create_tasks() -> Vec<impl Future<Output = u64>> {
        const MAX_FUTURES: u64 = 100;
        (0..MAX_FUTURES).map(dummy).collect::<Vec<_>>()
    }

    async fn dummy(val: u64) -> u64 {
        val
    }

    /// Counts itself into `val`, holds that slot across a timer await, then releases it.
    /// Returns how many `dummy_checked` futures sharing `val` were in flight (including
    /// this one) right after it started.
    async fn dummy_checked(val: Arc<AtomicU8>) -> u8 {
        // `fetch_add` returns the previous count, so incrementing and reading happen in one
        // atomic step rather than via a separate load that another thread could interleave.
        let in_flight = val.fetch_add(1, Ordering::SeqCst) + 1;
        tokio::time::sleep(Duration::from_nanos(100)).await;
        val.fetch_sub(1, Ordering::SeqCst);
        in_flight
    }

    #[tokio::test]
    async fn test_futures_unordered_iter() {
        let max_concurrent: usize = 10;
        let mut fut_iter = FuturesUnorderedIter::new(max_concurrent, create_tasks());

        // Every dummy future completes on its first poll, so this test expects results in
        // the order the futures were pulled from the iterator.
        let mut expected = 0;
        while let Some(result) = fut_iter.next().await {
            assert_eq!(result, expected);
            expected += 1;
            assert!(fut_iter.borrow_inner().len() <= max_concurrent);
        }
        assert_eq!(expected, 100);
    }

    #[tokio::test]
    async fn test_futures_unordered() {
        let mut fut_iter = FuturesUnorderedBounded::new(10);
        for i in 0..100 {
            fut_iter.push(dummy(i));
        }

        let result = fut_iter.next().await;
        assert!(result.is_some());
        assert_eq!(fut_iter.len_queue(), 89);
        assert_eq!(fut_iter.borrow_inner().len(), 10);

        let _result = fut_iter.next().await;
        assert_eq!(fut_iter.len_queue(), 88);
        assert_eq!(fut_iter.borrow_inner().len(), 10);

        // `clear_queue` drops both the queued and the running futures.
        fut_iter.clear_queue();
        assert_eq!(fut_iter.len_queue(), 0);
        assert!(fut_iter.borrow_inner().is_empty());
        assert_eq!(fut_iter.borrow_inner().len(), 0);

        let result = fut_iter.next().await;
        assert!(result.is_none());

        // The set can be refilled after it reported exhaustion.
        for i in 0..100 {
            fut_iter.push(dummy(i));
        }
        assert_eq!(fut_iter.len_queue(), 90);
        assert_eq!(fut_iter.borrow_inner().len(), 10);

        let result = fut_iter.next().await;
        assert!(result.is_some());
        assert_eq!(fut_iter.len_queue(), 89);
        assert_eq!(fut_iter.borrow_inner().len(), 10);
        assert_eq!(fut_iter.borrow_queue().len(), 89);
        assert_eq!(fut_iter.borrow_mut_queue().len(), 89);

        assert_eq!(fut_iter.into_inner().len(), 10);
    }

    #[tokio::test]
    async fn test_max_concurrent() {
        let test_value = Arc::new(AtomicU8::new(0));
        let tasks = (0..50)
            .map(|_| dummy_checked(Arc::clone(&test_value)))
            .collect::<Vec<_>>();

        let mut fut_iter = FuturesUnorderedIter::new(10, tasks);
        let mut max_so_far = 0;
        let mut count = 0;
        while let Some(max_res) = fut_iter.next().await {
            max_so_far = cmp::max(max_so_far, max_res);
            if count > 10 {
                // Raise the limit part-way through; the remaining futures may use it.
                fut_iter.set_max_concurrent(20);
                break;
            }
            count += 1;
        }
        assert_eq!(max_so_far, 10);
        let max = fut_iter.collect::<Vec<_>>().await.into_iter().max();
        assert_eq!(max, Some(20));

        let test_value = Arc::new(AtomicU8::new(0));
        let mut fut_iter = FuturesUnorderedBounded::new(10);
        for _ in 0..50 {
            fut_iter.push(dummy_checked(Arc::clone(&test_value)));
        }
        let mut max_so_far = 0;
        while let Some(max_res) = fut_iter.next().await {
            max_so_far = cmp::max(max_so_far, max_res);
        }
        assert_eq!(max_so_far, 10);

        fut_iter.set_max_concurrent(20);
        for _ in 0..50 {
            fut_iter.push(dummy_checked(Arc::clone(&test_value)));
        }
        while let Some(max_res) = fut_iter.next().await {
            max_so_far = cmp::max(max_so_far, max_res);
        }
        assert_eq!(max_so_far, 20);
    }

    #[tokio::test]
    async fn test_iter_inner() {
        let max_concurrent: usize = 10;
        let tasks = (0..100).map(dummy).collect::<Vec<_>>();

        let mut fut_iter = FuturesUnorderedIter::new(max_concurrent, tasks);

        assert_eq!(fut_iter.borrow_inner().len(), 10);
        assert!(!fut_iter.borrow_inner().is_empty());
        {
            let inner = fut_iter.borrow_inner();
            assert_eq!(inner.len(), 10);
        }
        {
            let inner = fut_iter.borrow_mut_inner();
            assert_eq!(inner.len(), 10);
        }

        fut_iter.borrow_mut_inner().push(dummy(123));
        assert_eq!(fut_iter.borrow_inner().len(), 11);

        fut_iter.borrow_mut_inner().clear();
        assert_eq!(fut_iter.borrow_inner().len(), 0);

        let inner = fut_iter.into_inner();
        assert_eq!(inner.len(), 0);
    }
}