1#![forbid(unsafe_code)]
2
3use std::borrow::Cow;
4use std::fmt::{self, Write};
5use std::future::Future;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8use std::time::Duration;
9
10use futures_core::{ready, Stream};
11use pin_project_lite::pin_project;
12use puzz_core::body::{Body, BodyExt, Bytes};
13use puzz_core::http::header;
14use puzz_core::response::IntoResponse;
15use puzz_core::{BoxError, Response};
16
/// A response that streams server-sent events (`text/event-stream`) to the client.
///
/// Build it with [`Sse::new`] and optionally enable keep-alive comments with
/// [`Sse::keep_alive`]. The bounds `S` must satisfy to become a response are on the
/// `IntoResponse` impl.
pub struct Sse<S> {
    /// Source of the events sent to the client.
    stream: S,
    /// Keep-alive configuration; `None` means no keep-alive comments are sent.
    keep_alive: Option<KeepAlive>,
}
21
22impl<S> Sse<S> {
23 pub fn new(stream: S) -> Self {
24 Self {
25 stream,
26 keep_alive: None,
27 }
28 }
29
30 pub fn keep_alive(mut self, keep_alive: KeepAlive) -> Self {
31 self.keep_alive = Some(keep_alive);
32 self
33 }
34}
35
36impl<S> fmt::Debug for Sse<S> {
37 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
38 f.debug_struct("Sse")
39 .field("stream", &format_args!("{}", std::any::type_name::<S>()))
40 .field("keep_alive", &self.keep_alive)
41 .finish()
42 }
43}
44
45impl<S, E> IntoResponse for Sse<S>
46where
47 S: Stream<Item = Result<Event, E>> + Send + 'static,
48 E: Into<BoxError>,
49{
50 fn into_response(self) -> Response {
51 let body = SseBody {
52 event_stream: self.stream,
53 keep_alive: self.keep_alive.map(KeepAliveStream::new),
54 };
55
56 Response::builder()
57 .header(header::CONTENT_TYPE, mime::TEXT_EVENT_STREAM.as_ref())
58 .header(header::CACHE_CONTROL, "no-cache")
59 .body(body.boxed())
60 .unwrap()
61 }
62}
63
pin_project! {
    // Response body for `Sse`: serializes each event from `event_stream` and, when
    // `keep_alive` is set, emits keep-alive comments while the stream is idle.
    struct SseBody<S> {
        // Stream of `Result<Event, E>` items to serialize.
        #[pin]
        event_stream: S,
        // Idle timer producing keep-alive comments; `None` disables them.
        #[pin]
        keep_alive: Option<KeepAliveStream>,
    }
}
72
73impl<S, E> Body for SseBody<S>
74where
75 S: Stream<Item = Result<Event, E>>,
76{
77 type Error = E;
78
79 fn poll_next(
80 self: Pin<&mut Self>,
81 cx: &mut Context<'_>,
82 ) -> Poll<Option<Result<Bytes, Self::Error>>> {
83 let this = self.project();
84
85 match this.event_stream.poll_next(cx) {
86 Poll::Pending => {
87 if let Some(keep_alive) = this.keep_alive.as_pin_mut() {
88 keep_alive
89 .poll_event(cx)
90 .map(|e| Some(Ok(Bytes::from(e.to_string()))))
91 } else {
92 Poll::Pending
93 }
94 }
95 Poll::Ready(Some(Ok(event))) => {
96 if let Some(keep_alive) = this.keep_alive.as_pin_mut() {
97 keep_alive.reset();
98 }
99 Poll::Ready(Some(Ok(Bytes::from(event.to_string()))))
100 }
101 Poll::Ready(Some(Err(error))) => Poll::Ready(Some(Err(error))),
102 Poll::Ready(None) => Poll::Ready(None),
103 }
104 }
105}
106
/// A single server-sent event.
///
/// Build it with the setter methods (`data`, `json_data`, `comment`, `event`,
/// `retry`, `id`). Its `Display` impl produces the `text/event-stream` wire format.
#[derive(Debug, Default)]
pub struct Event {
    /// Value of the `id:` field.
    id: Option<String>,
    /// Payload, written as one or more `data:` lines.
    data: Option<DataType>,
    /// Value of the `event:` field (the event type).
    event: Option<String>,
    /// Text of a `:` comment line.
    comment: Option<String>,
    /// Reconnection time, written as the `retry:` field in whole milliseconds.
    retry: Option<Duration>,
}
115
/// Payload of an `Event`, recording how it was supplied.
#[derive(Debug)]
enum DataType {
    /// Plain text set via `Event::data`; may span several `\n`-separated lines.
    Text(String),

    /// JSON text produced by `serde_json::to_string` in `Event::json_data`.
    Json(String),
}
122
123impl Event {
124 pub fn data<T>(mut self, data: T) -> Event
125 where
126 T: Into<String>,
127 {
128 let data = data.into();
129 assert_eq!(
130 memchr::memchr(b'\r', data.as_bytes()),
131 None,
132 "SSE data cannot contain carriage returns",
133 );
134 self.data = Some(DataType::Text(data));
135 self
136 }
137
138 pub fn json_data<T>(mut self, data: T) -> serde_json::Result<Event>
139 where
140 T: serde::Serialize,
141 {
142 self.data = Some(DataType::Json(serde_json::to_string(&data)?));
143 Ok(self)
144 }
145
146 pub fn comment<T>(mut self, comment: T) -> Event
147 where
148 T: Into<String>,
149 {
150 let comment = comment.into();
151 assert_eq!(
152 memchr::memchr2(b'\r', b'\n', comment.as_bytes()),
153 None,
154 "SSE comment cannot contain newlines or carriage returns"
155 );
156 self.comment = Some(comment);
157 self
158 }
159
160 pub fn event<T>(mut self, event: T) -> Event
161 where
162 T: Into<String>,
163 {
164 let event = event.into();
165 assert_eq!(
166 memchr::memchr2(b'\r', b'\n', event.as_bytes()),
167 None,
168 "SSE event name cannot contain newlines or carriage returns"
169 );
170 self.event = Some(event);
171 self
172 }
173
174 pub fn retry(mut self, duration: Duration) -> Event {
175 self.retry = Some(duration);
176 self
177 }
178
179 pub fn id<T>(mut self, id: T) -> Event
180 where
181 T: Into<String>,
182 {
183 let id = id.into();
184 assert_eq!(
185 memchr::memchr3(b'\r', b'\n', b'\0', id.as_bytes()),
186 None,
187 "Event ID cannot contain newlines, carriage returns or null characters",
188 );
189 self.id = Some(id);
190 self
191 }
192}
193
194impl fmt::Display for Event {
195 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
196 if let Some(comment) = &self.comment {
197 ":".fmt(f)?;
198 comment.fmt(f)?;
199 f.write_char('\n')?;
200 }
201
202 if let Some(event) = &self.event {
203 "event: ".fmt(f)?;
204 event.fmt(f)?;
205 f.write_char('\n')?;
206 }
207
208 match &self.data {
209 Some(DataType::Text(data)) => {
210 for line in data.split('\n') {
211 "data: ".fmt(f)?;
212 line.fmt(f)?;
213 f.write_char('\n')?;
214 }
215 }
216
217 Some(DataType::Json(data)) => {
218 "data:".fmt(f)?;
219 data.fmt(f)?;
220 f.write_char('\n')?;
221 }
222 None => {}
223 }
224
225 if let Some(id) = &self.id {
226 "id: ".fmt(f)?;
227 id.fmt(f)?;
228 f.write_char('\n')?;
229 }
230
231 if let Some(duration) = &self.retry {
232 "retry:".fmt(f)?;
233
234 let secs = duration.as_secs();
235 let millis = duration.subsec_millis();
236
237 if secs > 0 {
238 secs.fmt(f)?;
240
241 if millis < 10 {
243 f.write_str("00")?;
244 } else if millis < 100 {
245 f.write_char('0')?;
246 }
247 }
248
249 millis.fmt(f)?;
251
252 f.write_char('\n')?;
253 }
254
255 f.write_char('\n')?;
256
257 Ok(())
258 }
259}
260
/// Configuration for keep-alive comments sent while an SSE stream is idle.
#[derive(Debug, Clone)]
pub struct KeepAlive {
    /// Text of the comment line sent as the keep-alive.
    comment_text: Cow<'static, str>,
    /// Maximum idle time before a keep-alive comment is sent.
    max_interval: Duration,
}
266
267impl KeepAlive {
268 pub fn new() -> Self {
269 Self {
270 comment_text: Cow::Borrowed(""),
271 max_interval: Duration::from_secs(15),
272 }
273 }
274
275 pub fn interval(mut self, time: Duration) -> Self {
276 self.max_interval = time;
277 self
278 }
279
280 pub fn text<I>(mut self, text: I) -> Self
281 where
282 I: Into<Cow<'static, str>>,
283 {
284 self.comment_text = text.into();
285 self
286 }
287}
288
289impl Default for KeepAlive {
290 fn default() -> Self {
291 Self::new()
292 }
293}
294
pin_project! {
    // Idle timer that yields a keep-alive comment event whenever
    // `keep_alive.max_interval` elapses without the timer being reset.
    #[derive(Debug)]
    pub(crate) struct KeepAliveStream {
        // Keep-alive configuration: comment text and interval.
        keep_alive: KeepAlive,
        // Deadline timer; pushed forward on every reset.
        #[pin]
        alive_timer: tokio::time::Sleep,
    }
}
303
304impl KeepAliveStream {
305 pub(crate) fn new(keep_alive: KeepAlive) -> Self {
306 Self {
307 alive_timer: tokio::time::sleep(keep_alive.max_interval),
308 keep_alive,
309 }
310 }
311
312 pub(crate) fn reset(self: Pin<&mut Self>) {
313 let this = self.project();
314 this.alive_timer
315 .reset(tokio::time::Instant::now() + this.keep_alive.max_interval);
316 }
317
318 pub(crate) fn poll_event(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Event> {
319 let this = self.as_mut().project();
320
321 ready!(this.alive_timer.poll(cx));
322
323 let comment_str = this.keep_alive.comment_text.clone();
324 let event = Event::default().comment(comment_str);
325
326 self.reset();
327
328 Poll::Ready(event)
329 }
330}