1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
use pin_project::pin_project;

use super::{ConcurrentStream, Consumer, ConsumerState};
use core::future::Future;
use core::num::NonZeroUsize;
use core::pin::Pin;

/// A concurrent stream that yields at most the first `n` items of the
/// underlying stream.
///
/// This `struct` is created by the [`take`] method on [`ConcurrentStream`]. See its
/// documentation for more.
///
/// [`take`]: ConcurrentStream::take
#[derive(Debug)]
pub struct Take<CS: ConcurrentStream> {
    /// The underlying concurrent stream.
    inner: CS,
    /// Maximum number of items forwarded from `inner`.
    limit: usize,
}

impl<CS: ConcurrentStream> Take<CS> {
    pub(crate) fn new(inner: CS, limit: usize) -> Self {
        Self { inner, limit }
    }
}

impl<CS: ConcurrentStream> ConcurrentStream for Take<CS> {
    type Item = CS::Item;
    type Future = CS::Future;

    /// Drives the inner stream, wrapping `consumer` in a [`TakeConsumer`] so
    /// that at most `self.limit` futures are forwarded to it.
    async fn drive<C>(self, consumer: C) -> C::Output
    where
        C: Consumer<Self::Item, Self::Future>,
    {
        let consumer = TakeConsumer {
            inner: consumer,
            count: 0,
            limit: self.limit,
        };
        self.inner.drive(consumer).await
    }

    // `take` only caps the total number of items; it does not change how many
    // may be in flight at once, so the inner concurrency limit is forwarded.
    fn concurrency_limit(&self) -> Option<NonZeroUsize> {
        self.inner.concurrency_limit()
    }

    /// Clamps the inner stream's bounds to `self.limit`, because at most
    /// `limit` items are ever yielded. This mirrors `core::iter::Take::size_hint`.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let (lower, upper) = self.inner.size_hint();
        let lower = lower.min(self.limit);
        // Even an unbounded inner stream is bounded above by `limit` here.
        let upper = Some(upper.map_or(self.limit, |upper| upper.min(self.limit)));
        (lower, upper)
    }
}

/// Consumer adapter used by `Take`: forwards futures to `inner` until `limit`
/// of them have been sent, then signals the driver to stop.
#[pin_project]
struct TakeConsumer<C> {
    /// Upper bound on the number of futures forwarded to `inner`.
    limit: usize,
    /// Number of futures forwarded to `inner` so far.
    count: usize,
    /// The downstream consumer; structurally pinned because its methods are
    /// called through `Pin<&mut _>` projections.
    #[pin]
    inner: C,
}
impl<C, Item, Fut> Consumer<Item, Fut> for TakeConsumer<C>
where
    Fut: Future<Output = Item>,
    C: Consumer<Item, Fut>,
{
    type Output = C::Output;

    /// Forwards `future` to the inner consumer unless the limit has already
    /// been reached. Returns `ConsumerState::Break` once `limit` futures have
    /// been forwarded; otherwise returns the inner consumer's state.
    async fn send(self: Pin<&mut Self>, future: Fut) -> ConsumerState {
        let this = self.project();
        // Check *before* forwarding. Otherwise `take(0)` would still deliver
        // one item, and any send after the limit would leak through.
        if *this.count >= *this.limit {
            return ConsumerState::Break;
        }
        *this.count += 1;
        let state = this.inner.send(future).await;
        if *this.count >= *this.limit {
            ConsumerState::Break
        } else {
            state
        }
    }

    /// Delegates progress to the inner consumer.
    async fn progress(self: Pin<&mut Self>) -> ConsumerState {
        let this = self.project();
        this.inner.progress().await
    }

    /// Delegates flushing to the inner consumer and returns its output.
    async fn flush(self: Pin<&mut Self>) -> Self::Output {
        let this = self.project();
        this.inner.flush().await
    }
}

#[cfg(test)]
mod test {
    use crate::prelude::*;
    use futures_lite::stream;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Runs an infinite counting stream through `.co().take(limit)`, asserts
    /// every delivered item is below `limit`, and returns how many items
    /// reached `for_each`.
    fn count_taken(limit: usize) -> usize {
        let seen = AtomicUsize::new(0);
        let seen_ref = &seen;
        futures_lite::future::block_on(async {
            let mut n = 0;
            stream::iter(std::iter::from_fn(|| {
                let v = n;
                n += 1;
                Some(v)
            }))
            .co()
            .take(limit)
            .for_each(|n| async move {
                assert!(n < limit);
                seen_ref.fetch_add(1, Ordering::SeqCst);
            })
            .await;
        });
        seen.load(Ordering::SeqCst)
    }

    #[test]
    fn take() {
        assert_eq!(count_taken(5), 5);
    }

    #[test]
    fn take_zero_yields_nothing() {
        assert_eq!(count_taken(0), 0);
    }
}