futures_rx/stream_ext/
sample.rs1use std::{
2 pin::Pin,
3 task::{Context, Poll},
4};
5
6use futures::{
7 stream::{Fuse, FusedStream},
8 Stream, StreamExt,
9};
10use pin_project_lite::pin_project;
11
pin_project! {
    /// Stream produced by sampling one stream with another (see the `stream.sample(sampler)`
    /// call in this file's test): it buffers the most recent item of `stream` and yields it
    /// whenever `sampler` produces an item.
    #[must_use = "streams do nothing unless polled"]
    pub struct Sample<S1: Stream, S2: Stream> {
        // Source whose items are sampled; fused so it may be polled again after it ends.
        #[pin]
        stream: Fuse<S1>,
        // Stream whose items trigger emissions; fused for the same reason.
        #[pin]
        sampler: Fuse<S2>,
        // Most recent source item that has not been emitted yet.
        latest_event: Option<S1::Item>,
    }
}
23
24impl<S1: Stream, S2: Stream> Sample<S1, S2> {
25 pub(crate) fn new(stream: S1, sampler: S2) -> Self {
26 Self {
27 stream: stream.fuse(),
28 sampler: sampler.fuse(),
29 latest_event: None,
30 }
31 }
32}
33
34impl<S1: Stream, S2: Stream> FusedStream for Sample<S1, S2>
35where
36 S1: FusedStream,
37 S2: FusedStream,
38{
39 fn is_terminated(&self) -> bool {
40 self.stream.is_terminated() || self.sampler.is_terminated()
41 }
42}
43
44impl<S1: Stream, S2: Stream> Stream for Sample<S1, S2> {
45 type Item = S1::Item;
46
47 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
48 let this = self.project();
49
50 if let Poll::Ready(Some(event)) = this.stream.poll_next(cx) {
51 this.latest_event.replace(event);
52
53 cx.waker().wake_by_ref();
54 }
55
56 match this.sampler.poll_next(cx) {
57 Poll::Ready(Some(_)) => {
58 if this.latest_event.is_some() {
59 Poll::Ready(this.latest_event.take())
60 } else {
61 cx.waker().wake_by_ref();
62
63 Poll::Pending
64 }
65 }
66 Poll::Ready(None) => Poll::Ready(this.latest_event.take()),
67 Poll::Pending => Poll::Pending,
68 }
69 }
70
71 fn size_hint(&self) -> (usize, Option<usize>) {
72 let (lower_left, upper_left) = self.stream.size_hint();
73 let (lower_right, upper_right) = self.sampler.size_hint();
74
75 (lower_left.min(lower_right), upper_left.max(upper_right))
76 }
77}
78
#[cfg(test)]
mod test {
    use futures::{executor::block_on, StreamExt};
    use futures_time::time::Duration;

    use crate::RxExt;

    // The source ticks every 20ms (six ticks) and yields its tick index; the sampler
    // ticks every 50ms (six ticks). The assertion pins the sampled indices 1, 3 and 5.
    #[test]
    fn smoke() {
        block_on(async {
            let source = futures_time::stream::interval(Duration::from_millis(20))
                .take(6)
                .enumerate()
                .map(|(tick, _)| tick);
            let sampler = futures_time::stream::interval(Duration::from_millis(50)).take(6);

            let sampled: Vec<_> = source.sample(sampler).collect().await;

            assert_eq!(sampled, [1, 3, 5]);
        });
    }
}