par_stream/
state_stream.rs1use crate::common::*;
4use tokio::sync::oneshot;
5
6#[pin_project]
14pub struct StateStream<T> {
15 #[pin]
16 receiver: Option<oneshot::Receiver<T>>,
17 value: Option<T>,
18}
19
20impl<T> StateStream<T> {
21 pub fn new(init: T) -> Self {
23 Self {
24 value: Some(init),
25 receiver: None,
26 }
27 }
28}
29
30impl<T> Stream for StateStream<T> {
31 type Item = Handle<T>;
32
33 fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
34 let mut this = self.project();
35
36 Ready(loop {
37 if let Some(value) = this.value.take() {
38 let (tx, rx) = oneshot::channel();
39 this.receiver.set(Some(rx));
40 break Some(Handle {
41 inner: Some(Inner { value, sender: tx }),
42 });
43 } else if let Some(receiver) = this.receiver.as_mut().as_pin_mut() {
44 match ready!(receiver.poll(cx)) {
45 Ok(value) => {
46 *this.value = Some(value);
47 this.receiver.set(None);
48 }
49 Err(_) => {
50 this.receiver.set(None);
51 break None;
52 }
53 }
54 } else {
55 break None;
56 }
57 })
58 }
59}
60
61pub struct Handle<T> {
63 inner: Option<Inner<T>>,
64}
65
66struct Inner<T> {
67 value: T,
68 sender: oneshot::Sender<T>,
69}
70
71impl<T> Handle<T> {
72 fn inner(&self) -> &Inner<T> {
73 self.inner.as_ref().unwrap()
74 }
75
76 pub fn send(mut self) -> Result<(), T> {
78 let Inner { value, sender } = self.inner.take().unwrap();
79 sender.send(value)
80 }
81
82 pub fn take(mut self) -> T {
84 self.inner.take().unwrap().value
85 }
86
87 pub fn close(mut self) {
89 let _ = self.inner.take();
90 }
91}
92
93impl<T> Drop for Handle<T> {
94 fn drop(&mut self) {
95 if let Some(Inner { value, sender }) = self.inner.take() {
96 let _ = sender.send(value);
97 }
98 }
99}
100
101impl<T> Debug for Handle<T>
102where
103 T: Debug,
104{
105 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
106 self.inner().value.fmt(f)
107 }
108}
109
110impl<T> Display for Handle<T>
111where
112 T: Display,
113{
114 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
115 self.inner().value.fmt(f)
116 }
117}
118
119impl<T> PartialEq<T> for Handle<T>
120where
121 T: PartialEq,
122{
123 fn eq(&self, other: &T) -> bool {
124 self.inner().value.eq(other)
125 }
126}
127
128impl<T> PartialOrd<T> for Handle<T>
129where
130 T: PartialOrd,
131{
132 fn partial_cmp(&self, other: &T) -> Option<cmp::Ordering> {
133 self.inner().value.partial_cmp(other)
134 }
135}
136
137impl<T> Hash for Handle<T>
138where
139 T: Hash,
140{
141 fn hash<H>(&self, state: &mut H)
142 where
143 H: Hasher,
144 {
145 self.inner().value.hash(state);
146 }
147}
148
149impl<T> Deref for Handle<T> {
150 type Target = T;
151
152 fn deref(&self) -> &Self::Target {
153 &self.inner().value
154 }
155}
156
157impl<T> DerefMut for Handle<T> {
158 fn deref_mut(&mut self) -> &mut Self::Target {
159 &mut self.inner.as_mut().unwrap().value
160 }
161}
162
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{stream::StreamExt as _, utils::async_test};

    async_test! {
        async fn state_stream_test() {
            let quota = 100;

            let count: usize = stream::repeat(())
                .with_state(0)
                .filter_map(|((), mut cost)| async move {
                    if *cost < quota {
                        *cost += 1;
                        cost.send().unwrap();
                        Some(())
                    } else {
                        cost.close();
                        None
                    }
                })
                .count()
                .await;

            assert_eq!(count, quota);
        }

        async fn state_stream_simple_test() {
            // Mutations made through the handle must survive `send` and `drop`,
            // and `take` must hand back the latest state and end the stream.
            {
                let mut state_stream = StateStream::new(0);

                let mut handle = state_stream.next().await.unwrap();
                assert_eq!(*handle, 0);
                *handle += 1;
                handle.send().unwrap();

                let mut handle = state_stream.next().await.unwrap();
                assert_eq!(*handle, 1);
                *handle += 1;
                drop(handle);

                let handle = state_stream.next().await.unwrap();
                assert_eq!(handle.take(), 2);

                assert!(state_stream.next().await.is_none());
            }

            // Sending after the stream is gone gives the state back as `Err`.
            {
                let mut state_stream = StateStream::new(0);
                let handle = state_stream.next().await.unwrap();
                drop(state_stream);
                assert_eq!(handle.send(), Err(0));
            }

            // Closing a handle ends the stream.
            {
                let mut state_stream = StateStream::new(0);
                let handle = state_stream.next().await.unwrap();
                handle.close();
                assert!(state_stream.next().await.is_none());
            }
        }
    }
}