1use std::pin::Pin;
5use std::task::{Context, Poll};
6use std::time::Duration;
7
8use tokio::sync::mpsc;
9use tokio::time::{self, Interval};
10
11use crate::message::HubEvent;
12
13pub type Bytes = Vec<u8>;
15
16pub struct SseStream {
25 receiver: mpsc::Receiver<HubEvent>,
26 heartbeat: Interval,
27}
28
29impl SseStream {
30 pub fn new(receiver: mpsc::Receiver<HubEvent>, heartbeat_interval: Duration) -> Self {
35 Self {
36 receiver,
37 heartbeat: time::interval(heartbeat_interval),
38 }
39 }
40
41 fn format_event(event: &HubEvent) -> Bytes {
43 event.to_sse_string().into_bytes()
44 }
45
46 fn heartbeat_bytes() -> Bytes {
48 b": heartbeat\n\n".to_vec()
49 }
50}
51
52impl futures_core::Stream for SseStream {
53 type Item = Result<Bytes, std::io::Error>;
54
55 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
56 match self.receiver.poll_recv(cx) {
58 Poll::Ready(Some(event)) => {
59 self.heartbeat.reset();
61 return Poll::Ready(Some(Ok(Self::format_event(&event))));
62 }
63 Poll::Ready(None) => {
64 return Poll::Ready(None);
66 }
67 Poll::Pending => {}
68 }
69
70 match self.heartbeat.poll_tick(cx) {
72 Poll::Ready(_) => Poll::Ready(Some(Ok(Self::heartbeat_bytes()))),
73 Poll::Pending => Poll::Pending,
74 }
75 }
76}
77
78pub fn sse_retry_directive(reconnect_ms: u64) -> Bytes {
82 format!("retry: {}\n\n", reconnect_ms).into_bytes()
83}
84
85pub fn sse_comment(text: &str) -> Bytes {
87 format!(": {}\n\n", text).into_bytes()
88}
89
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::HubEvent;
    use chrono::Utc;
    use futures_core::Stream;
    use serde_json::json;
    use tokio::sync::mpsc;

    /// Builds a `HubEvent` with a fixed JSON payload for the given id and topic.
    fn sample_event(id: u64, topic: &str) -> HubEvent {
        HubEvent {
            id,
            topic: topic.to_string(),
            data: json!({"key": "value"}),
            timestamp: Utc::now(),
        }
    }

    /// Polls `stream` exactly once with a no-op waker and returns the raw poll result.
    fn poll_once(stream: &mut SseStream) -> Poll<Option<Result<Bytes, std::io::Error>>> {
        let mut cx = std::task::Context::from_waker(std::task::Waker::noop());
        Pin::new(stream).poll_next(&mut cx)
    }

    #[test]
    fn format_event_produces_valid_sse() {
        let rendered = SseStream::format_event(&sample_event(1, "test/topic"));
        let text = String::from_utf8(rendered).unwrap();

        assert!(text.starts_with("event: test/topic\n"));
        assert!(text.contains("data: "));
        assert!(text.contains("id: 1\n"));
        assert!(text.ends_with("\n\n"));
    }

    #[test]
    fn heartbeat_bytes_format() {
        let text = String::from_utf8(SseStream::heartbeat_bytes()).unwrap();
        assert_eq!(text, ": heartbeat\n\n");
    }

    #[test]
    fn sse_retry_directive_format() {
        let text = String::from_utf8(sse_retry_directive(3000)).unwrap();
        assert_eq!(text, "retry: 3000\n\n");
    }

    #[test]
    fn sse_comment_format() {
        let text = String::from_utf8(sse_comment("connected")).unwrap();
        assert_eq!(text, ": connected\n\n");
    }

    #[tokio::test]
    async fn stream_delivers_events() {
        let (tx, rx) = mpsc::channel(256);
        let mut stream = SseStream::new(rx, Duration::from_secs(60));
        tx.try_send(sample_event(10, "app/deploy")).unwrap();

        match poll_once(&mut stream) {
            Poll::Ready(Some(Ok(bytes))) => {
                let text = String::from_utf8(bytes).unwrap();
                assert!(text.contains("event: app/deploy"));
                assert!(text.contains("id: 10"));
            }
            other => panic!("expected Ready(Some(Ok)), got {:?}", other),
        }
    }

    #[tokio::test]
    async fn stream_ends_when_channel_closed() {
        let (tx, rx) = mpsc::channel::<HubEvent>(256);
        let mut stream = SseStream::new(rx, Duration::from_secs(60));
        drop(tx);

        match poll_once(&mut stream) {
            Poll::Ready(None) => {}
            other => panic!("expected Ready(None), got {:?}", other),
        }
    }

    #[tokio::test]
    async fn stream_emits_heartbeat_when_idle() {
        // Keep the sender alive so the channel stays open but empty.
        let (_tx, rx) = mpsc::channel::<HubEvent>(256);
        let mut stream = SseStream::new(rx, Duration::from_millis(1));

        tokio::time::sleep(Duration::from_millis(10)).await;

        match poll_once(&mut stream) {
            Poll::Ready(Some(Ok(bytes))) => {
                let text = String::from_utf8(bytes).unwrap();
                assert_eq!(text, ": heartbeat\n\n");
            }
            other => panic!("expected heartbeat, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn multiple_events_delivered_in_order() {
        let (tx, rx) = mpsc::channel(256);
        let mut stream = SseStream::new(rx, Duration::from_secs(60));
        tx.try_send(sample_event(1, "a")).unwrap();
        tx.try_send(sample_event(2, "b")).unwrap();

        match poll_once(&mut stream) {
            Poll::Ready(Some(Ok(bytes))) => {
                let text = String::from_utf8(bytes).unwrap();
                assert!(text.contains("id: 1"));
            }
            other => panic!("expected first event, got {:?}", other),
        }

        match poll_once(&mut stream) {
            Poll::Ready(Some(Ok(bytes))) => {
                let text = String::from_utf8(bytes).unwrap();
                assert!(text.contains("id: 2"));
            }
            other => panic!("expected second event, got {:?}", other),
        }
    }
}