1use std::sync::{Arc, atomic::{AtomicBool, Ordering}};
2use tokio::sync::{mpsc, oneshot};
3use futures::Stream;
4use std::pin::Pin;
5use std::task::{Context, Poll};
6
/// Sending half of an [`EventStream`], created by [`EventStream::split`].
///
/// Forwards events over the bounded event channel and delivers an optional
/// final result over a one-shot channel.
pub struct EventStreamProducer<T, R> {
    /// Channel used to forward events to the paired consumer.
    sender: mpsc::Sender<T>,
    /// One-shot sender for the final result. `None` once `end` has consumed it,
    /// or when the originating stream had already ended before `split`.
    result_sender: Option<oneshot::Sender<R>>,
    /// Set by `end`; once true, `push` discards events.
    is_done: Arc<AtomicBool>,
}
14
15impl<T: Send + 'static, R: Send + 'static> EventStreamProducer<T, R> {
16 pub async fn push(&self, event: T) -> Result<(), mpsc::error::SendError<T>> {
17 if self.is_done.load(Ordering::Relaxed) {
18 return Ok(());
19 }
20 self.sender.send(event).await
21 }
22
23 pub fn end(&mut self, result: Option<R>) {
24 self.is_done.store(true, Ordering::Relaxed);
25 if let Some(sender) = self.result_sender.take() {
26 if let Some(res) = result {
27 let _ = sender.send(res);
28 }
29 }
30 }
31}
32
/// Receiving half of an [`EventStream`], created by [`EventStream::split`].
///
/// Yields events via `next` or the [`Stream`] impl, and the final result via `result`.
pub struct EventStreamConsumer<T, R> {
    /// Receiving end of the bounded event channel.
    receiver: mpsc::Receiver<T>,
    /// One-shot receiver for the producer's final result.
    result_receiver: Option<oneshot::Receiver<R>>,
}
37
38impl<T: Send + 'static, R: Send + 'static> EventStreamConsumer<T, R> {
39 pub async fn next(&mut self) -> Option<T> {
40 self.receiver.recv().await
41 }
42
43 pub async fn result(mut self) -> Option<R> {
44 if let Some(rx) = self.result_receiver.take() {
45 rx.await.ok()
46 } else {
47 None
48 }
49 }
50}
51
52impl<T: Send + 'static, R: Send + 'static> Stream for EventStreamConsumer<T, R> {
53 type Item = T;
54
55 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
56 self.receiver.poll_recv(cx)
57 }
58}
59
/// Single-owner event stream: a bounded event channel plus a one-shot final result.
///
/// Events are written with `push` and read with `next` or via [`Stream`];
/// `end` finishes the stream and publishes the result. `split` turns it into
/// separate [`EventStreamProducer`] / [`EventStreamConsumer`] halves.
pub struct EventStream<T, R> {
    /// Event sender; set to `None` by `end` so readers can observe end-of-stream.
    pub(crate) sender: Option<mpsc::Sender<T>>,
    /// Event receiver; `split` takes it out, leaving `None`.
    /// NOTE(review): in the code visible here this `Arc` is never cloned and all
    /// access goes through `&self`/`&mut self`; confirm the `Arc<Mutex<_>>` is needed.
    receiver: Arc<std::sync::Mutex<Option<mpsc::Receiver<T>>>>,
    /// Receives the value passed to `end`; consumed by `result` or handed out by `split`.
    final_result_receiver: oneshot::Receiver<R>,
    /// Sender for the final result; taken (and dropped or used) by `end`.
    pub(crate) final_result_sender: Option<oneshot::Sender<R>>,
    /// Set by `end`; once true, `push` discards events.
    pub(crate) is_done: Arc<AtomicBool>,
}
68
69impl<T, R> Default for EventStream<T, R>
70where
71 T: Send + 'static,
72 R: Send + 'static,
73{
74 fn default() -> Self {
75 Self::new()
76 }
77}
78
79impl<T, R> EventStream<T, R>
80where
81 T: Send + 'static,
82 R: Send + 'static,
83{
84 pub fn new() -> Self {
86 let (tx, rx) = mpsc::channel(32);
87 let (result_tx, result_rx) = oneshot::channel();
88
89 Self {
90 sender: Some(tx),
91 receiver: Arc::new(std::sync::Mutex::new(Some(rx))),
92 final_result_receiver: result_rx,
93 final_result_sender: Some(result_tx),
94 is_done: Arc::new(AtomicBool::new(false)),
95 }
96 }
97
98 pub async fn push(&self, event: T) -> Result<(), mpsc::error::SendError<T>> {
100 if self.is_done.load(Ordering::Relaxed) {
101 return Ok(()); }
103 if let Some(ref sender) = self.sender {
104 sender.send(event).await
105 } else {
106 Ok(())
107 }
108 }
109
110 pub fn end(&mut self, result: Option<R>) {
112 self.is_done.store(true, Ordering::Relaxed);
113 if let Some(sender) = self.final_result_sender.take() {
114 if let Some(res) = result {
115 let _ = sender.send(res);
116 }
117 }
118 self.sender.take();
120 }
121
122 pub async fn next(&mut self) -> Option<T> {
124 let mut rx = {
125 let mut guard = self.receiver.lock().unwrap();
126 guard.take()
127 };
128 let result = if let Some(ref mut receiver) = rx {
129 receiver.recv().await
130 } else {
131 None
132 };
133 if let Some(receiver) = rx {
134 let mut guard = self.receiver.lock().unwrap();
135 *guard = Some(receiver);
136 }
137 result
138 }
139
140 pub async fn result(self) -> Result<R, oneshot::error::RecvError> {
142 self.final_result_receiver.await
143 }
144
145 pub fn split(self) -> (EventStreamProducer<T, R>, EventStreamConsumer<T, R>) {
147 let receiver = {
148 let mut guard = self.receiver.lock().unwrap();
149 guard.take().expect("EventStream receiver already taken")
150 };
151
152 let producer = EventStreamProducer {
153 sender: self.sender.unwrap_or_else(|| mpsc::channel(1).0),
154 result_sender: self.final_result_sender,
155 is_done: self.is_done,
156 };
157
158 let consumer = EventStreamConsumer {
159 receiver,
160 result_receiver: Some(self.final_result_receiver),
161 };
162
163 (producer, consumer)
164 }
165}
166
167impl<T, R> Stream for EventStream<T, R>
168where
169 T: Send + 'static,
170 R: Send + 'static,
171{
172 type Item = T;
173
174 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
175 let mut receiver_guard = self.receiver.lock().unwrap();
176 if let Some(ref mut rx) = *receiver_guard {
177 match rx.poll_recv(cx) {
178 Poll::Ready(Some(item)) => Poll::Ready(Some(item)),
179 Poll::Ready(None) => Poll::Ready(None),
180 Poll::Pending => Poll::Pending,
181 }
182 } else {
183 Poll::Ready(None)
184 }
185 }
186}
187
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;

    /// Pushed events come back in order through the `Stream` impl after `end`.
    #[tokio::test]
    async fn test_event_stream_collect() {
        let mut stream = EventStream::<i32, String>::new();
        for value in [1, 2, 3] {
            stream.push(value).await.unwrap();
        }
        stream.end(Some("done".to_string()));

        let collected: Vec<_> = stream.collect().await;
        assert_eq!(collected, vec![1, 2, 3]);
    }

    /// The value passed to `end` is returned by `result`.
    #[tokio::test]
    async fn test_event_stream_result() {
        let mut stream = EventStream::<i32, String>::new();
        stream.end(Some("done".to_string()));

        let outcome = stream.result().await.unwrap();
        assert_eq!(outcome, "done");
    }

    /// The inherent `next` yields a pushed event.
    #[tokio::test]
    async fn test_event_stream_next() {
        let mut stream = EventStream::<i32, ()>::new();
        stream.push(42).await.unwrap();

        assert_eq!(stream.next().await, Some(42));
    }

    /// Events pushed by the producer half arrive at the consumer half in order.
    #[tokio::test]
    async fn test_split_producer_consumer() {
        let (producer, mut consumer) = EventStream::<i32, String>::new().split();

        for value in [1, 2, 3] {
            producer.push(value).await.unwrap();
        }
        for expected in [1, 2, 3] {
            assert_eq!(consumer.next().await, Some(expected));
        }
    }

    /// The producer's final result reaches the consumer.
    #[tokio::test]
    async fn test_split_result() {
        let (mut producer, consumer) = EventStream::<i32, String>::new().split();
        producer.end(Some("final".to_string()));

        assert_eq!(consumer.result().await, Some("final".to_string()));
    }

    /// The consumer half works as a `Stream` fed from another task.
    #[tokio::test]
    async fn test_split_consumer_stream_trait() {
        let (mut producer, consumer) = EventStream::<i32, String>::new().split();

        tokio::spawn(async move {
            for value in [10, 20] {
                producer.push(value).await.unwrap();
            }
            producer.end(Some("done".to_string()));
        });

        let collected: Vec<_> = consumer.collect().await;
        assert_eq!(collected, vec![10, 20]);
    }
}
275}