apalis_core/poller/
stream.rs

1use std::{
2    pin::Pin,
3    sync::atomic::Ordering,
4    task::{Context, Poll},
5};
6
7use futures::{stream::FusedStream, Stream, StreamExt};
8use pin_project_lite::pin_project;
9
10use super::{controller::Controller, STOPPED};
11
12// Macro for pin projection used in `BackendStream`.
13pin_project! {
14    /// `BackendStream` is a wrapper around another stream `S`.
15    /// It controls the flow of the stream based on the `Controller` state.
16    #[derive(Debug)]
17    pub struct BackendStream<S> {
18        #[pin]
19        stream: S,
20        controller: Controller,
21    }
22}
23
24impl<S> BackendStream<S> {
25    /// Creates a new `BackendStream` from a given stream and a shared `Controller`.
26    pub fn new(stream: S, controller: Controller) -> Self {
27        Self { stream, controller }
28    }
29}
30impl<S: Stream<Item = T> + Unpin, T> Stream for BackendStream<S> {
31    type Item = T;
32
33    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
34        let this = self.get_mut();
35        if this.controller.is_plugged() {
36            match this.stream.poll_next_unpin(cx) {
37                Poll::Ready(Some(item)) => Poll::Ready(Some(item)),
38                Poll::Ready(None) => Poll::Ready(None), // Inner stream is exhausted
39                Poll::Pending => Poll::Pending,
40            }
41        } else if this.controller.is_stopped() {
42            Poll::Ready(None)
43        } else {
44            Poll::Pending
45        }
46    }
47
48    fn size_hint(&self) -> (usize, Option<usize>) {
49        self.stream.size_hint()
50    }
51}
52
53impl<S: Unpin + Stream> FusedStream for BackendStream<S> {
54    fn is_terminated(&self) -> bool {
55        self.controller.state.load(Ordering::Relaxed) == STOPPED
56    }
57}
58
#[cfg(test)]
mod tests {
    use super::*;
    use futures::stream::{self, StreamExt};
    use std::pin::Pin;
    use std::task::{Context, Poll};
    use tokio::time::{self, Duration};
    use tokio_stream::wrappers::IntervalStream;

    /// A finite stream yielding `1, 2, 3`.
    fn mock_stream() -> impl Stream<Item = i32> {
        stream::iter(vec![1, 2, 3])
    }

    /// An endless stream ticking every `duration`.
    fn interval_stream(duration: Duration) -> IntervalStream {
        IntervalStream::new(time::interval(duration))
    }

    #[test]
    fn test_backend_stream_plugged() {
        let controller = Controller::new();
        controller.plug();
        let mut backend_stream = BackendStream::new(mock_stream(), controller);

        let mut context = Context::from_waker(futures::task::noop_waker_ref());
        match Pin::new(&mut backend_stream).poll_next(&mut context) {
            Poll::Ready(Some(item)) => assert_eq!(item, 1),
            _ => panic!("Expected item from stream"),
        }
    }

    #[test]
    fn test_backend_stream_unplugged() {
        let controller = Controller::new();
        controller.unplug();
        let mut backend_stream = BackendStream::new(mock_stream(), controller);

        let mut context = Context::from_waker(futures::task::noop_waker_ref());
        match Pin::new(&mut backend_stream).poll_next(&mut context) {
            Poll::Pending => (),
            _ => panic!("Expected Poll::Pending"),
        }
    }

    #[test]
    fn test_backend_stream_plug_unplug() {
        let controller = Controller::new();
        controller.unplug();
        let mut backend_stream = BackendStream::new(mock_stream(), controller.clone());

        let mut context = Context::from_waker(futures::task::noop_waker_ref());
        match Pin::new(&mut backend_stream).poll_next(&mut context) {
            Poll::Pending => (),
            _ => panic!("Expected Poll::Pending"),
        };
        controller.plug();

        match Pin::new(&mut backend_stream).poll_next(&mut context) {
            Poll::Ready(Some(item)) => assert_eq!(item, 1),
            _ => panic!("Expected item from stream"),
        }
        controller.unplug();
        match Pin::new(&mut backend_stream).poll_next(&mut context) {
            Poll::Pending => (),
            _ => panic!("Expected Poll::Pending"),
        };
        controller.plug();

        match Pin::new(&mut backend_stream).poll_next(&mut context) {
            Poll::Ready(Some(item)) => assert_eq!(item, 2),
            _ => panic!("Expected item from stream"),
        }
    }

    // A stopped controller ends the stream immediately and marks it terminated.
    #[test]
    fn test_backend_stream_stopped() {
        let controller = Controller::new();
        // Drive the shared state to STOPPED directly: the same value `is_terminated` reads.
        controller.state.store(STOPPED, Ordering::Relaxed);
        let mut backend_stream = BackendStream::new(mock_stream(), controller);

        assert!(backend_stream.is_terminated());
        let mut context = Context::from_waker(futures::task::noop_waker_ref());
        match Pin::new(&mut backend_stream).poll_next(&mut context) {
            Poll::Ready(None) => (),
            _ => panic!("Expected the stream to end once stopped"),
        }
    }

    // A plugged stream forwards every inner item, then ends when the inner stream does.
    #[test]
    fn test_backend_stream_exhausted() {
        let controller = Controller::new();
        controller.plug();
        let mut backend_stream = BackendStream::new(mock_stream(), controller);

        let mut context = Context::from_waker(futures::task::noop_waker_ref());
        for expected in 1..=3 {
            match Pin::new(&mut backend_stream).poll_next(&mut context) {
                Poll::Ready(Some(item)) => assert_eq!(item, expected),
                _ => panic!("Expected item from stream"),
            }
        }
        match Pin::new(&mut backend_stream).poll_next(&mut context) {
            Poll::Ready(None) => (),
            _ => panic!("Expected the stream to be exhausted"),
        }
    }

    // Test that BackendStream polls items from an interval stream when plugged
    #[tokio::test]
    async fn test_backend_stream_with_interval_plugged() {
        let controller = Controller::new();
        controller.plug();
        let mut backend_stream =
            BackendStream::new(interval_stream(Duration::from_millis(100)), controller);

        // Polling the stream should yield an item
        backend_stream
            .next()
            .await
            .expect("Expected an item from the stream");
    }

    #[tokio::test]
    async fn test_backend_stream_with_interval_unplugged() {
        let controller = Controller::new();
        controller.unplug();
        let mut backend_stream =
            BackendStream::new(interval_stream(Duration::from_millis(100)), controller);

        // Using tokio::time::timeout to ensure that the stream doesn't yield an item
        match tokio::time::timeout(Duration::from_millis(200), backend_stream.next()).await {
            Ok(None) | Err(_) => (), // Expected as stream is unplugged
            _ => panic!("Expected no item from the stream"),
        }
    }

    #[tokio::test]
    async fn test_backend_stream_interval_plug_unplug() {
        let controller = Controller::new();
        controller.unplug();
        let mut backend_stream = BackendStream::new(
            interval_stream(Duration::from_millis(100)),
            controller.clone(),
        );

        // Using tokio::time::timeout to ensure that the stream doesn't yield an item
        match tokio::time::timeout(Duration::from_millis(200), backend_stream.next()).await {
            Err(_) => (),
            _ => panic!("Expected no item from the stream"),
        }
        controller.plug();
        backend_stream
            .next()
            .await
            .expect("Expected an item from the stream");
    }
}