par_stream/
state_stream.rs

1//! Stream and handle types for the [`with_state`](crate::stream::StreamExt::with_state) method.
2
3use crate::common::*;
4use tokio::sync::oneshot;
5
6/// Stream for the [`with_state`](super::StreamExt::with_state) method.
7///
8/// The stream produces a single [handle](Handle) to value `T` and
9/// pauses indefinitely until [`handle.send()`](Handle::send) or
10/// [`handle.close()`](Handle::close). Calling [`handle.send()`](Handle::send)
11/// returns the value to the stream, so that the stream can produce the handle again.
12/// [`handle.close()`](Handle::close) drops the handle and the close the stream.
13#[pin_project]
14pub struct StateStream<T> {
15    #[pin]
16    receiver: Option<oneshot::Receiver<T>>,
17    value: Option<T>,
18}
19
20impl<T> StateStream<T> {
21    /// Creates the stream with initial value `init`.
22    pub fn new(init: T) -> Self {
23        Self {
24            value: Some(init),
25            receiver: None,
26        }
27    }
28}
29
30impl<T> Stream for StateStream<T> {
31    type Item = Handle<T>;
32
33    fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
34        let mut this = self.project();
35
36        Ready(loop {
37            if let Some(value) = this.value.take() {
38                let (tx, rx) = oneshot::channel();
39                this.receiver.set(Some(rx));
40                break Some(Handle {
41                    inner: Some(Inner { value, sender: tx }),
42                });
43            } else if let Some(receiver) = this.receiver.as_mut().as_pin_mut() {
44                match ready!(receiver.poll(cx)) {
45                    Ok(value) => {
46                        *this.value = Some(value);
47                        this.receiver.set(None);
48                    }
49                    Err(_) => {
50                        this.receiver.set(None);
51                        break None;
52                    }
53                }
54            } else {
55                break None;
56            }
57        })
58    }
59}
60
61/// The handle maintains an unique reference to the state value for [StateStream].
62pub struct Handle<T> {
63    inner: Option<Inner<T>>,
64}
65
66struct Inner<T> {
67    value: T,
68    sender: oneshot::Sender<T>,
69}
70
71impl<T> Handle<T> {
72    fn inner(&self) -> &Inner<T> {
73        self.inner.as_ref().unwrap()
74    }
75
76    /// Returns the value to the associated stream.
77    pub fn send(mut self) -> Result<(), T> {
78        let Inner { value, sender } = self.inner.take().unwrap();
79        sender.send(value)
80    }
81
82    /// Takes the ownership of value and closes the associated stream.
83    pub fn take(mut self) -> T {
84        self.inner.take().unwrap().value
85    }
86
87    /// Discards the value and closes the associated stream.
88    pub fn close(mut self) {
89        let _ = self.inner.take();
90    }
91}
92
93impl<T> Drop for Handle<T> {
94    fn drop(&mut self) {
95        if let Some(Inner { value, sender }) = self.inner.take() {
96            let _ = sender.send(value);
97        }
98    }
99}
100
101impl<T> Debug for Handle<T>
102where
103    T: Debug,
104{
105    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
106        self.inner().value.fmt(f)
107    }
108}
109
110impl<T> Display for Handle<T>
111where
112    T: Display,
113{
114    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
115        self.inner().value.fmt(f)
116    }
117}
118
119impl<T> PartialEq<T> for Handle<T>
120where
121    T: PartialEq,
122{
123    fn eq(&self, other: &T) -> bool {
124        self.inner().value.eq(other)
125    }
126}
127
128impl<T> PartialOrd<T> for Handle<T>
129where
130    T: PartialOrd,
131{
132    fn partial_cmp(&self, other: &T) -> Option<cmp::Ordering> {
133        self.inner().value.partial_cmp(other)
134    }
135}
136
137impl<T> Hash for Handle<T>
138where
139    T: Hash,
140{
141    fn hash<H>(&self, state: &mut H)
142    where
143        H: Hasher,
144    {
145        self.inner().value.hash(state);
146    }
147}
148
149impl<T> Deref for Handle<T> {
150    type Target = T;
151
152    fn deref(&self) -> &Self::Target {
153        &self.inner().value
154    }
155}
156
157impl<T> DerefMut for Handle<T> {
158    fn deref_mut(&mut self) -> &mut Self::Target {
159        &mut self.inner.as_mut().unwrap().value
160    }
161}
162
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{stream::StreamExt as _, utils::async_test};

    async_test! {
        async fn state_stream_test() {
            let quota = 100;

            // Each item bumps the shared counter until the quota is reached,
            // at which point the handle is closed and the stream ends.
            let count: usize = stream::repeat(())
                .with_state(0)
                .filter_map(|((), mut cost)| async move {
                    if *cost < quota {
                        *cost += 1;
                        cost.send().unwrap();
                        Some(())
                    } else {
                        cost.close();
                        None
                    }
                })
                .count()
                .await;

            assert_eq!(count, quota);
        }

        async fn state_stream_simple_test() {
            // Mutations made through a handle survive both `send()` and an
            // implicit drop; `take()` yields the final value and ends the stream.
            {
                let mut state_stream = StateStream::new(0);

                let mut handle = state_stream.next().await.unwrap();
                assert_eq!(*handle, 0);
                *handle += 1;
                handle.send().unwrap();

                let mut handle = state_stream.next().await.unwrap();
                assert_eq!(*handle, 1);
                *handle += 1;
                drop(handle);

                let handle = state_stream.next().await.unwrap();
                assert_eq!(handle.take(), 2);

                assert!(state_stream.next().await.is_none());
                // The stream stays terminated once it has ended.
                assert!(state_stream.next().await.is_none());
            }

            // Sending back to a dropped stream hands the value to the caller.
            {
                let mut state_stream = StateStream::new(0);
                let handle = state_stream.next().await.unwrap();
                drop(state_stream);
                assert_eq!(handle.send(), Err(0));
            }

            // Closing the handle ends the stream.
            {
                let mut state_stream = StateStream::new(0);
                let handle = state_stream.next().await.unwrap();
                handle.close();
                assert!(state_stream.next().await.is_none());
            }
        }
    }
}