1use crate::io::{Read, Write, WriteExt};
4
5use super::StatusCode;
6
7pub trait EventData {
9 async fn write_to<W: Write>(self, writer: &mut W) -> Result<(), W::Error>;
11}
12
13impl EventData for core::fmt::Arguments<'_> {
14 async fn write_to<W: Write>(self, writer: &mut W) -> Result<(), W::Error> {
15 writer.write_fmt(self).await
16 }
17}
18
19impl EventData for &str {
20 async fn write_to<W: Write>(self, writer: &mut W) -> Result<(), W::Error> {
21 writer.write_all(self.as_bytes()).await
22 }
23}
24
/// With the `json` feature enabled, `Json` payloads delegate to `Json::do_write_to`.
#[cfg(feature = "json")]
impl<T: serde::Serialize> EventData for super::json::Json<T> {
    async fn write_to<W: Write>(self, writer: &mut W) -> Result<(), W::Error> {
        let json = self;
        json.do_write_to(writer).await
    }
}
31
/// Flags shared between an `EventWriter` and the shutdown logic in `write_events_until_shutdown`.
///
/// Both sides only hold shared references and are polled inside the same future, so `Cell`
/// provides the required interior mutability.
struct EventWriterState {
    /// `true` until the connection's shutdown signal has fired.
    is_running: core::cell::Cell<bool>,
    /// `true` while `EventWriter` is in the middle of writing an event or keep-alive.
    is_currently_writing_event: core::cell::Cell<bool>,
}

impl EventWriterState {
    /// State for a freshly opened stream: running, and not in the middle of an event.
    fn new() -> Self {
        Self {
            is_running: core::cell::Cell::new(true),
            is_currently_writing_event: core::cell::Cell::new(false),
        }
    }
}
45
46pub struct EventWriter<'a, W: Write> {
48 writer: W,
49 event_writer_state: &'a EventWriterState,
50}
51
52impl<W: Write> EventWriter<'_, W> {
53 async fn do_write<F: core::future::Future>(
54 event_writer_state: &EventWriterState,
55 write_task: F,
56 ) -> F::Output {
57 event_writer_state.is_currently_writing_event.set(true);
58
59 let result = write_task.await;
60
61 event_writer_state.is_currently_writing_event.set(false);
62
63 if !event_writer_state.is_running.get() {
65 return core::future::pending().await;
66 };
67
68 result
69 }
70
71 pub async fn write_keepalive(&mut self) -> Result<(), W::Error> {
73 Self::do_write(self.event_writer_state, async {
74 self.writer.write_all(b":\n\n").await?;
75
76 self.writer.flush().await
77 })
78 .await
79 }
80
81 pub async fn write_event<T: EventData>(
83 &mut self,
84 event: &str,
85 data: T,
86 ) -> Result<(), W::Error> {
87 pub struct DataWriter<W: Write> {
88 writer: W,
89 }
90
91 impl<W: Write> crate::io::ErrorType for DataWriter<W> {
92 type Error = W::Error;
93 }
94
95 impl<W: Write> Write for DataWriter<W> {
96 async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
97 for line in buf.split_inclusive(|&b| b == b'\n') {
98 self.writer.write_all(b"data:").await?;
99 self.writer.write_all(line).await?;
100 }
101
102 self.writer.write_all(b"\n").await?;
103
104 Ok(buf.len())
105 }
106
107 async fn flush(&mut self) -> Result<(), Self::Error> {
108 self.writer.flush().await
109 }
110 }
111
112 Self::do_write(self.event_writer_state, async {
113 self.writer.write_all(b"event:").await?;
114 self.writer.write_all(event.as_bytes()).await?;
115 self.writer.write_all(b"\n").await?;
116
117 data.write_to(&mut DataWriter {
118 writer: &mut self.writer,
119 })
120 .await?;
121
122 self.writer.write_all(b"\n").await?;
123
124 self.writer.flush().await
125 })
126 .await
127 }
128}
129
130async fn write_events_until_shutdown<E, F: core::future::Future<Output = Result<(), E>>>(
131 event_writer_state: &EventWriterState,
132 shutdown_signal: impl core::future::Future<Output = ()> + Unpin,
133 mut write_events: core::pin::Pin<&mut F>,
134) -> Result<(), E> {
135 let shutdown_task = async {
136 shutdown_signal.await;
137 event_writer_state.is_running.set(false);
138
139 core::future::pending().await
140 };
141
142 let write_events_task = core::future::poll_fn(|cx| {
143 use core::task::Poll;
144
145 if event_writer_state.is_running.get() {
146 return write_events.as_mut().poll(cx);
147 }
148
149 if !event_writer_state.is_currently_writing_event.get() {
150 return Poll::Ready(Ok(()));
151 }
152
153 if let Poll::Ready(result) = write_events.as_mut().poll(cx) {
154 return Poll::Ready(result);
155 }
156
157 if !event_writer_state.is_currently_writing_event.get() {
158 return Poll::Ready(Ok(()));
159 }
160
161 Poll::Pending
162 });
163
164 crate::futures::select(shutdown_task, write_events_task).await
165}
166
167pub trait EventSource {
169 async fn write_events<W: Write>(self, writer: EventWriter<W>) -> Result<(), W::Error>;
171}
172
173pub struct EventStream<S: EventSource>(pub S);
175
176impl<S: EventSource> EventStream<S> {
177 pub fn into_response(self) -> super::Response<impl super::HeadersIter, impl super::Body> {
179 super::Response {
180 status_code: StatusCode::OK,
181 headers: [
182 ("Cache-Control", "no-cache"),
183 ("Content-Type", "text/event-stream"),
184 ],
185 body: self,
186 }
187 }
188}
189
190impl<S: EventSource> super::Body for EventStream<S> {
191 async fn write_response_body<R: Read, W: Write<Error = R::Error>>(
192 self,
193 connection: super::Connection<'_, R>,
194 mut writer: W,
195 ) -> Result<(), W::Error> {
196 writer.flush().await?;
197
198 let shutdown_signal = connection.shutdown_signal.clone();
199
200 let event_writer_state = &EventWriterState::new();
201
202 let write_events = core::pin::pin!(connection.run_until_disconnection(
203 (),
204 self.0.write_events(EventWriter {
205 writer,
206 event_writer_state
207 })
208 ));
209
210 write_events_until_shutdown(event_writer_state, shutdown_signal, write_events).await
211 }
212}
213
214impl<S: EventSource> super::IntoResponse for EventStream<S> {
215 async fn write_to<R: Read, W: super::ResponseWriter<Error = R::Error>>(
216 self,
217 connection: super::Connection<'_, R>,
218 response_writer: W,
219 ) -> Result<crate::ResponseSent, W::Error> {
220 response_writer
221 .write_response(connection, self.into_response())
222 .await
223 }
224}
225
#[cfg(test)]
mod tests {
    use super::*;

    // Event source that writes the same event `write_count` times.
    #[derive(Clone)]
    struct TestEventSource {
        event: &'static str,
        data: &'static str,
        write_count: usize,
    }

    impl TestEventSource {
        // Returns a copy of the source configured to write `write_count` events.
        fn with_write_count(mut self, write_count: usize) -> Self {
            self.write_count = write_count;
            self
        }
    }

    impl EventSource for TestEventSource {
        async fn write_events<W: Write>(
            self,
            mut writer: EventWriter<'_, W>,
        ) -> Result<(), W::Error> {
            for _ in 0..self.write_count {
                writer.write_event(self.event, self.data).await?;
            }

            Ok(())
        }
    }

    // Writer that discards data and only counts the number of bytes written.
    struct CountWriteSize(usize);

    impl crate::io::ErrorType for CountWriteSize {
        type Error = core::convert::Infallible;
    }

    impl Write for CountWriteSize {
        async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
            let write_size = buf.len();

            self.0 += write_size;

            Ok(write_size)
        }

        async fn flush(&mut self) -> Result<(), Self::Error> {
            Ok(())
        }
    }

    // Writer that accepts at most one byte per call and yields to the executor each time, so
    // writing an event takes many polls.
    struct ThrottledWriter {
        write_size: usize,
    }

    impl crate::io::ErrorType for ThrottledWriter {
        type Error = core::convert::Infallible;
    }

    impl Write for ThrottledWriter {
        async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
            if buf.is_empty() {
                Ok(0)
            } else {
                self.write_size += 1;

                tokio::task::yield_now().await;

                Ok(1)
            }
        }

        async fn flush(&mut self) -> Result<(), Self::Error> {
            Ok(())
        }
    }

    // Checks that a shutdown requested while an event is partially written lets that event
    // finish, and that no further events are started afterwards.
    #[tokio::test]
    #[ntest::timeout(1000)]
    async fn wait_event_to_finish_writing() {
        use futures_util::FutureExt;

        let (shutdown_signal_tx, shutdown_signal_rx) = tokio::sync::oneshot::channel::<()>();

        let event_writer_state = &EventWriterState::new();

        let source = TestEventSource {
            event: "test",
            data: "test",
            write_count: 1,
        };

        // Measure the encoded size of exactly one event.
        let write_size = {
            let mut count_write_size = CountWriteSize(0);

            let _ = source
                .clone()
                .write_events(EventWriter {
                    writer: &mut count_write_size,
                    event_writer_state,
                })
                .await;

            count_write_size.0
        };

        // Measuring must leave the shared state idle and running, so it can be reused below.
        assert!(!event_writer_state.is_currently_writing_event.get());
        assert!(event_writer_state.is_running.get());

        let mut throttle_writer = ThrottledWriter { write_size: 0 };

        // Three events through the slow writer; shutdown should stop it after the first.
        let write_events = async {
            source
                .with_write_count(3)
                .write_events(EventWriter {
                    writer: &mut throttle_writer,
                    event_writer_state,
                })
                .await
        };

        {
            let task_shutdown_signal = core::pin::pin!(async {
                let _ = shutdown_signal_rx.await;
            });

            let task_write_events = core::pin::pin!(write_events);

            let mut task = core::pin::pin!(write_events_until_shutdown(
                event_writer_state,
                task_shutdown_signal,
                task_write_events,
            ));

            // Poll a few times so the first event is partially written before shutdown.
            for _ in 0..3 {
                assert_eq!(task.as_mut().now_or_never(), None);
            }

            let _ = shutdown_signal_tx.send(());

            let _ = task.await;
        }

        // Exactly one complete event was written: it was not cut short, and no second one began.
        assert_eq!(throttle_writer.write_size, write_size);
    }
}
371}