futures_concurrency/concurrent_stream/
take.rs

1use pin_project::pin_project;
2
3use super::{ConcurrentStream, Consumer, ConsumerState};
4use core::future::Future;
5use core::num::NonZeroUsize;
6use core::pin::Pin;
7
8/// A concurrent iterator that only iterates over the first `n` iterations of `iter`.
9///
10/// This `struct` is created by the [`take`] method on [`ConcurrentStream`]. See its
11/// documentation for more.
12///
13/// [`take`]: ConcurrentStream::take
14/// [`ConcurrentStream`]: trait.ConcurrentStream.html
15#[derive(Debug)]
16pub struct Take<CS: ConcurrentStream> {
17    inner: CS,
18    limit: usize,
19}
20
21impl<CS: ConcurrentStream> Take<CS> {
22    pub(crate) fn new(inner: CS, limit: usize) -> Self {
23        Self { inner, limit }
24    }
25}
26
27impl<CS: ConcurrentStream> ConcurrentStream for Take<CS> {
28    type Item = CS::Item;
29    type Future = CS::Future;
30
31    async fn drive<C>(self, consumer: C) -> C::Output
32    where
33        C: Consumer<Self::Item, Self::Future>,
34    {
35        self.inner
36            .drive(TakeConsumer {
37                inner: consumer,
38                count: 0,
39                limit: self.limit,
40            })
41            .await
42    }
43
44    // NOTE: this is the only interesting bit in this module. When a limit is
45    // set, this now starts using it.
46    fn concurrency_limit(&self) -> Option<NonZeroUsize> {
47        self.inner.concurrency_limit()
48    }
49
50    fn size_hint(&self) -> (usize, Option<usize>) {
51        self.inner.size_hint()
52    }
53}
54
55#[pin_project]
56struct TakeConsumer<C> {
57    #[pin]
58    inner: C,
59    count: usize,
60    limit: usize,
61}
62impl<C, Item, Fut> Consumer<Item, Fut> for TakeConsumer<C>
63where
64    Fut: Future<Output = Item>,
65    C: Consumer<Item, Fut>,
66{
67    type Output = C::Output;
68
69    async fn send(self: Pin<&mut Self>, future: Fut) -> ConsumerState {
70        let this = self.project();
71        *this.count += 1;
72        let state = this.inner.send(future).await;
73        if this.count >= this.limit {
74            ConsumerState::Break
75        } else {
76            state
77        }
78    }
79
80    async fn progress(self: Pin<&mut Self>) -> ConsumerState {
81        let this = self.project();
82        this.inner.progress().await
83    }
84
85    async fn flush(self: Pin<&mut Self>) -> Self::Output {
86        let this = self.project();
87        this.inner.flush().await
88    }
89}
90
#[cfg(test)]
mod test {
    use crate::prelude::*;
    use futures_lite::stream;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    /// `take(5)` on an unbounded stream processes exactly the first five items.
    #[test]
    fn take_yields_first_n_items() {
        futures_lite::future::block_on(async {
            let count = Arc::new(AtomicUsize::new(0));
            stream::iter(0..)
                .co()
                .take(5)
                .for_each(|n| {
                    let count = count.clone();
                    async move {
                        assert!(n < 5);
                        count.fetch_add(1, Ordering::Relaxed);
                    }
                })
                .await;
            assert_eq!(count.load(Ordering::Relaxed), 5);
        });
    }
}
111}