1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
use std::{
    pin::Pin,
    sync::atomic::Ordering,
    task::{Context, Poll},
};

use futures::{stream::FusedStream, Stream, StreamExt};
use pin_project_lite::pin_project;

use super::{controller::Controller, STOPPED};

// `pin_project!` generates a safe pin projection for `BackendStream`; the `#[pin]` field `stream` is structurally pinned.
pin_project! {
    /// `BackendStream` is a wrapper around another stream `S`.
    /// It controls the flow of the stream based on the `Controller` state:
    /// the inner stream is polled only while the controller is plugged, the
    /// wrapper reports end-of-stream once the controller is stopped, and it
    /// stays pending in any other state.
    #[derive(Debug)]
    pub struct BackendStream<S> {
        // The wrapped stream; marked `#[pin]` so it is structurally pinned.
        #[pin]
        stream: S,
        // Controller whose state decides whether `stream` is polled; callers may
        // keep a clone of it to steer this stream after construction.
        controller: Controller,
    }
}

impl<S> BackendStream<S> {
    /// Creates a new `BackendStream` from a given stream and a shared `Controller`.
    pub fn new(stream: S, controller: Controller) -> Self {
        Self { stream, controller }
    }
}
impl<S: Stream<Item = T> + Unpin, T> Stream for BackendStream<S> {
    type Item = T;

    /// Polls the inner stream only while the controller is plugged.
    ///
    /// * plugged: the inner stream's result is forwarded unchanged (an item,
    ///   `Pending`, or end-of-stream).
    /// * stopped: returns `Poll::Ready(None)` without polling the inner stream.
    /// * otherwise: returns `Poll::Pending` without polling the inner stream.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        if this.controller.is_plugged() {
            // Pass the inner result straight through: items, `Pending`, and `None` alike.
            this.stream.poll_next_unpin(cx)
        } else if this.controller.is_stopped() {
            Poll::Ready(None)
        } else {
            // NOTE(review): no waker is registered on this path, so a task parked on this
            // `Pending` is only re-polled if something else wakes it. Confirm whether
            // `Controller` wakes waiting tasks when it is plugged again.
            Poll::Pending
        }
    }

    /// Bounds on the number of items still to come.
    ///
    /// The lower bound is always 0. The controller can stop this stream at any
    /// time, after which it ends immediately, so no minimum can be promised.
    /// Once stopped, the stream is reported as empty.
    /// Otherwise the inner stream's upper bound is kept.
    fn size_hint(&self) -> (usize, Option<usize>) {
        if self.controller.is_stopped() {
            (0, Some(0))
        } else {
            let (_, upper) = self.stream.size_hint();
            (0, upper)
        }
    }
}

impl<S> FusedStream for BackendStream<S>
where
    S: Stream + Unpin,
{
    /// Reports termination once the controller's state equals `STOPPED`.
    ///
    /// The state is read with a `Relaxed` atomic load. Only the controller is
    /// consulted: the inner stream running out of items does not, by itself,
    /// make this return `true`.
    fn is_terminated(&self) -> bool {
        let current_state = self.controller.state.load(Ordering::Relaxed);
        current_state == STOPPED
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use futures::stream::{self, StreamExt};
    use std::pin::Pin;
    use std::task::{Context, Poll};
    use tokio::time::{self, Duration};
    use tokio_stream::wrappers::IntervalStream;

    /// A finite stream yielding `1`, `2`, `3` and then ending.
    fn mock_stream() -> impl Stream<Item = i32> {
        stream::iter(vec![1, 2, 3])
    }

    /// A tokio interval stream that ticks every `duration`.
    fn interval_stream(duration: Duration) -> IntervalStream {
        IntervalStream::new(time::interval(duration))
    }

    /// Polls `stream` exactly once with a no-op waker and returns the result.
    fn poll_once<St: Stream + Unpin>(stream: &mut St) -> Poll<Option<St::Item>> {
        let mut context = Context::from_waker(futures::task::noop_waker_ref());
        Pin::new(stream).poll_next(&mut context)
    }

    #[test]
    fn test_backend_stream_plugged() {
        let controller = Controller::new();
        controller.plug();
        let mut backend_stream = BackendStream::new(mock_stream(), controller);

        // Every item is forwarded in order, followed by the inner stream's end.
        for expected in 1..=3 {
            assert_eq!(poll_once(&mut backend_stream), Poll::Ready(Some(expected)));
        }
        assert_eq!(poll_once(&mut backend_stream), Poll::Ready(None));
    }

    #[test]
    fn test_backend_stream_unplugged() {
        let controller = Controller::new();
        controller.unplug();
        let mut backend_stream = BackendStream::new(mock_stream(), controller);

        assert_eq!(poll_once(&mut backend_stream), Poll::Pending);
    }

    #[test]
    fn test_backend_stream_plug_unplug() {
        let controller = Controller::new();
        controller.unplug();
        let mut backend_stream = BackendStream::new(mock_stream(), controller.clone());

        assert_eq!(poll_once(&mut backend_stream), Poll::Pending);

        controller.plug();
        assert_eq!(poll_once(&mut backend_stream), Poll::Ready(Some(1)));

        // Unplugging pauses the stream without dropping or skipping items.
        controller.unplug();
        assert_eq!(poll_once(&mut backend_stream), Poll::Pending);

        controller.plug();
        assert_eq!(poll_once(&mut backend_stream), Poll::Ready(Some(2)));
    }

    // Test that BackendStream polls items from an interval stream when plugged
    #[tokio::test]
    async fn test_backend_stream_with_interval_plugged() {
        let controller = Controller::new();
        controller.plug();
        let mut backend_stream =
            BackendStream::new(interval_stream(Duration::from_millis(100)), controller);

        // Polling the stream should yield an item
        backend_stream
            .next()
            .await
            .expect("Expected an item from the stream");
    }

    #[tokio::test]
    async fn test_backend_stream_with_interval_unplugged() {
        let controller = Controller::new();
        controller.unplug();
        let mut backend_stream =
            BackendStream::new(interval_stream(Duration::from_millis(100)), controller);

        // An unplugged (not stopped) stream must stay pending rather than end,
        // so the only acceptable outcome is the timeout elapsing.
        let outcome = tokio::time::timeout(Duration::from_millis(200), backend_stream.next()).await;
        assert!(outcome.is_err(), "Expected no item from the stream, got {:?}", outcome);
    }

    #[tokio::test]
    async fn test_backend_stream_interval_plug_unplug() {
        let controller = Controller::new();
        controller.unplug();
        let mut backend_stream = BackendStream::new(
            interval_stream(Duration::from_millis(100)),
            controller.clone(),
        );

        // While unplugged, the stream must neither yield nor end before the timeout.
        let outcome = tokio::time::timeout(Duration::from_millis(200), backend_stream.next()).await;
        assert!(outcome.is_err(), "Expected no item from the stream, got {:?}", outcome);

        controller.plug();
        backend_stream
            .next()
            .await
            .expect("Expected an item from the stream");
    }
}