// futures_concurrency/concurrent_stream/take.rs

use core::future::Future;
use core::num::NonZeroUsize;
use core::pin::{pin, Pin};

use pin_project::pin_project;

use super::{ConcurrentStream, Consumer, ConsumerState};

8#[derive(Debug)]
16pub struct Take<CS: ConcurrentStream> {
17 inner: CS,
18 limit: usize,
19}
20
21impl<CS: ConcurrentStream> Take<CS> {
22 pub(crate) fn new(inner: CS, limit: usize) -> Self {
23 Self { inner, limit }
24 }
25}
26
27impl<CS: ConcurrentStream> ConcurrentStream for Take<CS> {
28 type Item = CS::Item;
29 type Future = CS::Future;
30
31 async fn drive<C>(self, consumer: C) -> C::Output
32 where
33 C: Consumer<Self::Item, Self::Future>,
34 {
35 self.inner
36 .drive(TakeConsumer {
37 inner: consumer,
38 count: 0,
39 limit: self.limit,
40 })
41 .await
42 }
43
44 fn concurrency_limit(&self) -> Option<NonZeroUsize> {
47 self.inner.concurrency_limit()
48 }
49
50 fn size_hint(&self) -> (usize, Option<usize>) {
51 self.inner.size_hint()
52 }
53}
54
55#[pin_project]
56struct TakeConsumer<C> {
57 #[pin]
58 inner: C,
59 count: usize,
60 limit: usize,
61}
62impl<C, Item, Fut> Consumer<Item, Fut> for TakeConsumer<C>
63where
64 Fut: Future<Output = Item>,
65 C: Consumer<Item, Fut>,
66{
67 type Output = C::Output;
68
69 async fn send(self: Pin<&mut Self>, future: Fut) -> ConsumerState {
70 let this = self.project();
71 *this.count += 1;
72 let state = this.inner.send(future).await;
73 if this.count >= this.limit {
74 ConsumerState::Break
75 } else {
76 state
77 }
78 }
79
80 async fn progress(self: Pin<&mut Self>) -> ConsumerState {
81 let this = self.project();
82 this.inner.progress().await
83 }
84
85 async fn flush(self: Pin<&mut Self>) -> Self::Output {
86 let this = self.project();
87 this.inner.flush().await
88 }
89}
90
91#[cfg(test)]
92mod test {
93 use crate::prelude::*;
94 use futures_lite::stream;
95
96 #[test]
97 fn enumerate() {
98 futures_lite::future::block_on(async {
99 let mut n = 0;
100 stream::iter(std::iter::from_fn(|| {
101 let v = n;
102 n += 1;
103 Some(v)
104 }))
105 .co()
106 .take(5)
107 .for_each(|n| async move { assert!(n < 5) })
108 .await;
109 });
110 }
111}