// futures_concurrency/concurrent_stream/enumerate.rs

1use pin_project::pin_project;
2
3use super::{ConcurrentStream, Consumer};
4use core::future::Future;
5use core::num::NonZeroUsize;
6use core::pin::Pin;
7use core::task::{ready, Context, Poll};
8
9/// A concurrent iterator that yields the current count and the element during iteration.
10///
11/// This `struct` is created by the [`enumerate`] method on [`ConcurrentStream`]. See its
12/// documentation for more.
13///
14/// [`enumerate`]: ConcurrentStream::enumerate
15/// [`ConcurrentStream`]: trait.ConcurrentStream.html
16#[derive(Debug)]
17pub struct Enumerate<CS: ConcurrentStream> {
18    inner: CS,
19}
20
21impl<CS: ConcurrentStream> Enumerate<CS> {
22    pub(crate) fn new(inner: CS) -> Self {
23        Self { inner }
24    }
25}
26
27impl<CS: ConcurrentStream> ConcurrentStream for Enumerate<CS> {
28    type Item = (usize, CS::Item);
29    type Future = EnumerateFuture<CS::Future, CS::Item>;
30
31    async fn drive<C>(self, consumer: C) -> C::Output
32    where
33        C: Consumer<Self::Item, Self::Future>,
34    {
35        self.inner
36            .drive(EnumerateConsumer {
37                inner: consumer,
38                count: 0,
39            })
40            .await
41    }
42
43    fn concurrency_limit(&self) -> Option<NonZeroUsize> {
44        self.inner.concurrency_limit()
45    }
46
47    fn size_hint(&self) -> (usize, Option<usize>) {
48        self.inner.size_hint()
49    }
50}
51
52#[pin_project]
53struct EnumerateConsumer<C> {
54    #[pin]
55    inner: C,
56    count: usize,
57}
58impl<C, Item, Fut> Consumer<Item, Fut> for EnumerateConsumer<C>
59where
60    Fut: Future<Output = Item>,
61    C: Consumer<(usize, Item), EnumerateFuture<Fut, Item>>,
62{
63    type Output = C::Output;
64
65    async fn send(self: Pin<&mut Self>, future: Fut) -> super::ConsumerState {
66        let this = self.project();
67        let count = *this.count;
68        *this.count += 1;
69        this.inner.send(EnumerateFuture::new(future, count)).await
70    }
71
72    async fn progress(self: Pin<&mut Self>) -> super::ConsumerState {
73        let this = self.project();
74        this.inner.progress().await
75    }
76
77    async fn flush(self: Pin<&mut Self>) -> Self::Output {
78        let this = self.project();
79        this.inner.flush().await
80    }
81}
82
83/// Takes a future and maps it to another future via a closure
84#[derive(Debug)]
85#[pin_project::pin_project]
86pub struct EnumerateFuture<FutT, T>
87where
88    FutT: Future<Output = T>,
89{
90    done: bool,
91    #[pin]
92    fut_t: FutT,
93    count: usize,
94}
95
96impl<FutT, T> EnumerateFuture<FutT, T>
97where
98    FutT: Future<Output = T>,
99{
100    fn new(fut_t: FutT, count: usize) -> Self {
101        Self {
102            done: false,
103            fut_t,
104            count,
105        }
106    }
107}
108
109impl<FutT, T> Future for EnumerateFuture<FutT, T>
110where
111    FutT: Future<Output = T>,
112{
113    type Output = (usize, T);
114
115    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
116        let this = self.project();
117        if *this.done {
118            panic!("future has already been polled to completion once");
119        }
120
121        let item = ready!(this.fut_t.poll(cx));
122        *this.done = true;
123        Poll::Ready((*this.count, item))
124    }
125}
126
#[cfg(test)]
mod test {
    use crate::prelude::*;
    use futures_lite::stream;
    use futures_lite::StreamExt;
    use std::num::NonZeroUsize;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Each item must be paired with its position in the source stream, and
    /// every item must actually be visited.
    #[test]
    fn enumerate() {
        futures_lite::future::block_on(async {
            let visited = AtomicUsize::new(0);
            let visited = &visited;

            let mut n = 0;
            stream::iter(std::iter::from_fn(|| {
                let v = n;
                n += 1;
                Some(v)
            }))
            .take(5)
            .co()
            .limit(NonZeroUsize::new(1))
            .enumerate()
            .for_each(|(index, n)| async move {
                assert_eq!(index, n);
                visited.fetch_add(1, Ordering::SeqCst);
            })
            .await;

            // Guards against the per-item assertion passing vacuously on an empty stream.
            assert_eq!(visited.load(Ordering::SeqCst), 5);
        });
    }
}