Skip to main content

rho_core/
event_stream.rs

1use std::sync::{Arc, atomic::{AtomicBool, Ordering}};
2use tokio::sync::{mpsc, oneshot};
3use futures::Stream;
4use std::pin::Pin;
5use std::task::{Context, Poll};
6
7// === Producer / Consumer (from split) ===
8
9pub struct EventStreamProducer<T, R> {
10    sender: mpsc::Sender<T>,
11    result_sender: Option<oneshot::Sender<R>>,
12    is_done: Arc<AtomicBool>,
13}
14
15impl<T: Send + 'static, R: Send + 'static> EventStreamProducer<T, R> {
16    pub async fn push(&self, event: T) -> Result<(), mpsc::error::SendError<T>> {
17        if self.is_done.load(Ordering::Relaxed) {
18            return Ok(());
19        }
20        self.sender.send(event).await
21    }
22
23    pub fn end(&mut self, result: Option<R>) {
24        self.is_done.store(true, Ordering::Relaxed);
25        if let Some(sender) = self.result_sender.take() {
26            if let Some(res) = result {
27                let _ = sender.send(res);
28            }
29        }
30    }
31}
32
33pub struct EventStreamConsumer<T, R> {
34    receiver: mpsc::Receiver<T>,
35    result_receiver: Option<oneshot::Receiver<R>>,
36}
37
38impl<T: Send + 'static, R: Send + 'static> EventStreamConsumer<T, R> {
39    pub async fn next(&mut self) -> Option<T> {
40        self.receiver.recv().await
41    }
42
43    pub async fn result(mut self) -> Option<R> {
44        if let Some(rx) = self.result_receiver.take() {
45            rx.await.ok()
46        } else {
47            None
48        }
49    }
50}
51
52impl<T: Send + 'static, R: Send + 'static> Stream for EventStreamConsumer<T, R> {
53    type Item = T;
54
55    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
56        self.receiver.poll_recv(cx)
57    }
58}
59
/// Generic event stream for async iteration and result handling
///
/// Events of type `T` flow through a bounded `tokio::sync::mpsc` channel; a single
/// final value of type `R` is delivered through a `tokio::sync::oneshot` channel.
pub struct EventStream<T, R> {
    /// Event sender; `None` after `end` has dropped it to close the channel.
    pub(crate) sender: Option<mpsc::Sender<T>>,
    /// Lockable slot holding the event receiver; `None` while the receiver has
    /// been moved out of the slot.
    receiver: Arc<std::sync::Mutex<Option<mpsc::Receiver<T>>>>,
    /// Receiving side of the final-result one-shot, awaited by `result`.
    final_result_receiver: oneshot::Receiver<R>,
    /// Sending side of the final-result one-shot; taken by `end`.
    pub(crate) final_result_sender: Option<oneshot::Sender<R>>,
    /// Set by `end`; once set, `push` discards events.
    pub(crate) is_done: Arc<AtomicBool>,
}
68
69impl<T, R> Default for EventStream<T, R>
70where
71    T: Send + 'static,
72    R: Send + 'static,
73{
74    fn default() -> Self {
75        Self::new()
76    }
77}
78
79impl<T, R> EventStream<T, R>
80where
81    T: Send + 'static,
82    R: Send + 'static,
83{
84    /// Create a new EventStream
85    pub fn new() -> Self {
86        let (tx, rx) = mpsc::channel(32);
87        let (result_tx, result_rx) = oneshot::channel();
88
89        Self {
90            sender: Some(tx),
91            receiver: Arc::new(std::sync::Mutex::new(Some(rx))),
92            final_result_receiver: result_rx,
93            final_result_sender: Some(result_tx),
94            is_done: Arc::new(AtomicBool::new(false)),
95        }
96    }
97
98    /// Push an event to the stream
99    pub async fn push(&self, event: T) -> Result<(), mpsc::error::SendError<T>> {
100        if self.is_done.load(Ordering::Relaxed) {
101            return Ok(()); // Ignore pushes after done
102        }
103        if let Some(ref sender) = self.sender {
104            sender.send(event).await
105        } else {
106            Ok(())
107        }
108    }
109
110    /// End the stream with an optional final result
111    pub fn end(&mut self, result: Option<R>) {
112        self.is_done.store(true, Ordering::Relaxed);
113        if let Some(sender) = self.final_result_sender.take() {
114            if let Some(res) = result {
115                let _ = sender.send(res);
116            }
117        }
118        // Drop the sender to close the channel
119        self.sender.take();
120    }
121
122    /// Get the next event from the stream
123    pub async fn next(&mut self) -> Option<T> {
124        let mut rx = {
125            let mut guard = self.receiver.lock().unwrap();
126            guard.take()
127        };
128        let result = if let Some(ref mut receiver) = rx {
129            receiver.recv().await
130        } else {
131            None
132        };
133        if let Some(receiver) = rx {
134            let mut guard = self.receiver.lock().unwrap();
135            *guard = Some(receiver);
136        }
137        result
138    }
139
140    /// Get the final result when the stream ends
141    pub async fn result(self) -> Result<R, oneshot::error::RecvError> {
142        self.final_result_receiver.await
143    }
144
145    /// Split into a producer/consumer pair. Consumes the EventStream.
146    pub fn split(self) -> (EventStreamProducer<T, R>, EventStreamConsumer<T, R>) {
147        let receiver = {
148            let mut guard = self.receiver.lock().unwrap();
149            guard.take().expect("EventStream receiver already taken")
150        };
151
152        let producer = EventStreamProducer {
153            sender: self.sender.unwrap_or_else(|| mpsc::channel(1).0),
154            result_sender: self.final_result_sender,
155            is_done: self.is_done,
156        };
157
158        let consumer = EventStreamConsumer {
159            receiver,
160            result_receiver: Some(self.final_result_receiver),
161        };
162
163        (producer, consumer)
164    }
165}
166
167impl<T, R> Stream for EventStream<T, R>
168where
169    T: Send + 'static,
170    R: Send + 'static,
171{
172    type Item = T;
173
174    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
175        let mut receiver_guard = self.receiver.lock().unwrap();
176        if let Some(ref mut rx) = *receiver_guard {
177            match rx.poll_recv(cx) {
178                Poll::Ready(Some(item)) => Poll::Ready(Some(item)),
179                Poll::Ready(None) => Poll::Ready(None),
180                Poll::Pending => Poll::Pending,
181            }
182        } else {
183            Poll::Ready(None)
184        }
185    }
186}
187
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;

    #[tokio::test]
    async fn test_event_stream_collect() {
        let mut stream = EventStream::<i32, String>::new();

        // Queue three events, then close the stream with a final result.
        for value in 1..=3 {
            stream.push(value).await.unwrap();
        }
        stream.end(Some("done".to_string()));

        // Draining through the `Stream` impl yields the events in order.
        let collected: Vec<_> = stream.collect().await;
        assert_eq!(collected, vec![1, 2, 3]);
    }

    #[tokio::test]
    async fn test_event_stream_result() {
        let mut stream = EventStream::<i32, String>::new();
        stream.end(Some("done".to_string()));

        let outcome = stream.result().await.unwrap();
        assert_eq!(outcome, "done");
    }

    #[tokio::test]
    async fn test_event_stream_next() {
        let mut stream = EventStream::<i32, ()>::new();
        stream.push(42).await.unwrap();

        assert_eq!(stream.next().await, Some(42));
    }

    #[tokio::test]
    async fn test_split_producer_consumer() {
        let (producer, mut consumer) = EventStream::<i32, String>::new().split();

        for value in 1..=3 {
            producer.push(value).await.unwrap();
        }
        // Events arrive at the consumer in push order.
        for expected in 1..=3 {
            assert_eq!(consumer.next().await, Some(expected));
        }
    }

    #[tokio::test]
    async fn test_split_result() {
        let (mut producer, consumer) = EventStream::<i32, String>::new().split();

        producer.end(Some("final".to_string()));

        assert_eq!(consumer.result().await, Some("final".to_string()));
    }

    #[tokio::test]
    async fn test_split_consumer_stream_trait() {
        let (mut producer, consumer) = EventStream::<i32, String>::new().split();

        // Produce on a separate task; the event channel is closed by the time
        // the task finishes, which lets `collect` complete.
        tokio::spawn(async move {
            producer.push(10).await.unwrap();
            producer.push(20).await.unwrap();
            producer.end(Some("done".to_string()));
        });

        let collected: Vec<_> = consumer.collect().await;
        assert_eq!(collected, vec![10, 20]);
    }
}