1use std::{pin::Pin, time::Duration};
6
7use bytes::Bytes;
8use http_body::Body;
9use http_body_util::BodyExt;
10
11use crate::{error::BoxError, response::Response};
12
/// Error message used when a response's `Content-Type` is not an event stream.
const ERR_INVALID_CONTENT_TYPE: &str = "Content-Type returned by server is NOT text/event-stream";

// Field names recognised on an event-stream line (`<field>: <value>`).
const DATA: &str = "data";
const EVENT: &str = "event";
const ID: &str = "id";
const RETRY: &str = "retry";

// One flag per field, OR-ed into `EventBuffer::bitset` when that field is seen.
const BIT_DATA: u8 = 1 << 0;
const BIT_EVENT: u8 = 1 << 1;
const BIT_ID: u8 = 1 << 2;
const BIT_RETRY: u8 = 1 << 3;
30
31pub trait SseExt<B>
33where
34 B: Body<Data = Bytes> + Unpin,
35 B::Error: Into<BoxError>,
36{
37 fn into_sse(self) -> Result<SseReader<B>, BoxError>;
41}
42
43impl<B> SseExt<B> for Response<B>
44where
45 B: Body<Data = Bytes> + Unpin,
46 B::Error: Into<BoxError>,
47{
48 fn into_sse(self) -> Result<SseReader<B>, BoxError> {
49 SseReader::into_sse(self)
50 }
51}
52
/// One event decoded from a `text/event-stream` body.
///
/// The raw fields are public; prefer the accessor methods, which apply the
/// protocol defaults (e.g. the `"message"` event type).
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct SseEvent {
    /// `data` lines joined with `\n`; `None` if the event had no `data` field.
    pub data: Option<String>,
    /// Value of the `event` field, if one was sent.
    pub event: Option<String>,
    /// Value of the `id` field; the reader leaves this `None` when the id was
    /// absent or empty.
    pub id: Option<String>,
    /// Reconnection delay taken from the `retry` field, if one was sent.
    pub retry: Option<Duration>,
}

impl SseEvent {
    /// Returns the event type.
    ///
    /// This is the `event` field, or `"message"` when that field is absent or
    /// empty: an empty `event:` line selects the default type rather than
    /// producing an unnamed event.
    pub fn event(&self) -> &str {
        match self.event.as_deref() {
            Some(name) if !name.is_empty() => name,
            _ => "message",
        }
    }

    /// Returns the event payload, if the event carried any `data` field.
    pub fn data(&self) -> Option<&str> {
        self.data.as_deref()
    }

    /// Returns the event's own id, if it set a non-empty one.
    pub fn id(&self) -> Option<&str> {
        self.id.as_deref()
    }

    /// Returns the reconnection delay requested by this event, if any.
    pub fn retry(&self) -> Option<Duration> {
        self.retry
    }
}
87
88#[derive(Default)]
93struct EventBuffer {
94 bitset: u8,
96 data: String,
97 event: Option<String>,
98 id: Option<String>,
99 retry: Option<Duration>,
100}
101
102impl EventBuffer {
103 fn reset(&mut self) {
105 self.bitset = 0;
106 self.data.clear();
107 self.event = None;
108 self.id = None;
109 self.retry = None;
110 }
111
112 fn has_field(&self) -> bool {
114 self.bitset != 0
115 }
116
117 fn is_set_id(&self) -> bool {
119 self.bitset & BIT_ID != 0
120 }
121
122 fn dispatch(&mut self) -> SseEvent {
124 let event = SseEvent {
125 event: self.event.take(),
126 data: if self.bitset & BIT_DATA != 0 {
127 Some(std::mem::take(&mut self.data))
128 } else {
129 None
130 },
131 id: self.id.take().filter(|s| !s.is_empty()),
132 retry: self.retry.take(),
133 };
134 self.reset();
135 event
136 }
137}
138
139pub struct SseReader<B> {
143 body: B,
144 buffer: Vec<u8>,
146 last_event_id: String,
150 is_first_line: bool,
152 pending: EventBuffer,
154}
155
156impl<B> SseReader<B>
157where
158 B: Body<Data = Bytes> + Unpin,
159 B::Error: Into<BoxError>,
160{
161 pub fn into_sse(resp: Response<B>) -> Result<Self, BoxError> {
163 if !resp.status().is_success() {
164 return Err(format!("Server returned error status: {}", resp.status()).into());
165 }
166
167 let content_type = resp
169 .headers()
170 .get(http::header::CONTENT_TYPE)
171 .and_then(|v| v.to_str().ok())
172 .unwrap_or("");
173
174 if !content_type.starts_with(mime::TEXT_EVENT_STREAM.essence_str()) {
175 return Err(ERR_INVALID_CONTENT_TYPE.into());
176 }
177
178 Ok(Self {
179 body: resp.into_body(),
180 buffer: Vec::new(),
181 last_event_id: String::new(),
182 is_first_line: true,
183 pending: EventBuffer::default(),
184 })
185 }
186
187 pub fn last_event_id(&self) -> &str {
191 &self.last_event_id
192 }
193
194 pub async fn read(&mut self) -> Result<Option<SseEvent>, BoxError> {
199 loop {
200 while let Some(line) = self.next_line() {
202 if let Some(event) = self.process_line(line)? {
203 return Ok(Some(event));
204 }
205 }
206
207 match Pin::new(&mut self.body).frame().await {
209 Some(Ok(frame)) => {
210 if let Ok(data) = frame.into_data() {
211 self.buffer.extend_from_slice(&data);
212 }
213 }
214 Some(Err(e)) => return Err(e.into()),
215 None => {
216 if !self.buffer.is_empty() {
219 self.buffer.push(b'\n');
220 while let Some(line) = self.next_line() {
221 if let Some(event) = self.process_line(line)? {
222 return Ok(Some(event));
223 }
224 }
225 }
226 if self.pending.has_field() {
228 return Ok(Some(self.dispatch_pending()));
229 }
230 return Ok(None);
231 }
232 }
233 }
234 }
235
236 fn next_line(&mut self) -> Option<String> {
241 let pos = self.buffer.iter().position(|&b| b == b'\n' || b == b'\r')?;
242
243 let terminator = self.buffer[pos];
244 let mut line_bytes: Vec<u8> = self.buffer.drain(..pos).collect();
245
246 self.buffer.remove(0);
248
249 if terminator == b'\r' && self.buffer.first() == Some(&b'\n') {
251 self.buffer.remove(0);
252 }
253
254 if self.is_first_line {
256 self.is_first_line = false;
257 if line_bytes.starts_with(&[0xEF, 0xBB, 0xBF]) {
258 line_bytes.drain(..3);
259 }
260 }
261
262 Some(String::from_utf8_lossy(&line_bytes).into_owned())
263 }
264
265 fn process_line(&mut self, line: String) -> Result<Option<SseEvent>, BoxError> {
269 if line.is_empty() {
270 if self.pending.has_field() {
272 return Ok(Some(self.dispatch_pending()));
273 }
274 return Ok(None);
276 }
277
278 if line.starts_with(':') {
280 return Ok(None);
281 }
282
283 let (field, value) = match line.find(':') {
286 Some(idx) => {
287 let v = line[idx + 1..]
289 .strip_prefix(' ')
290 .unwrap_or(&line[idx + 1..]);
291 (&line[..idx], v.to_string())
292 }
293 None => (line.as_str(), String::new()),
294 };
295
296 match field {
297 DATA => {
298 if self.pending.bitset & BIT_DATA != 0 {
301 self.pending.data.push('\n');
302 }
303 self.pending.data.push_str(&value);
304 self.pending.bitset |= BIT_DATA;
305 }
306 EVENT => {
307 self.pending.event = Some(value);
308 self.pending.bitset |= BIT_EVENT;
309 }
310 ID if !value.contains('\0') => {
312 self.pending.id = Some(value);
313 self.pending.bitset |= BIT_ID;
314 }
315 RETRY => {
316 if let Ok(ms) = value.parse::<u64>() {
318 self.pending.retry = Some(Duration::from_millis(ms));
319 self.pending.bitset |= BIT_RETRY;
320 }
321 }
322 _ => {} }
324
325 Ok(None)
326 }
327
328 fn dispatch_pending(&mut self) -> SseEvent {
330 if self.pending.is_set_id() {
333 self.last_event_id = self.pending.id.as_deref().unwrap_or_default().to_owned();
334 }
335 self.pending.dispatch()
336 }
337}
338
#[cfg(test)]
mod sse_reader_tests {
    //! Behavioural tests for `SseReader` response validation and stream parsing.

    use std::collections::VecDeque;
    use std::convert::Infallible;
    use std::pin::Pin;
    use std::task::{Context, Poll};
    use std::time::Duration;

    use bytes::Bytes;
    use http::header;
    use http_body::{Body, Frame};
    use http_body_util::Full;

    use super::SseReader;
    use crate::response::Response;

    /// Builds a `200 OK` event-stream response whose body is `body` in one frame.
    fn make_response(body: &'static str) -> Response<Full<Bytes>> {
        Response::builder()
            .header(header::CONTENT_TYPE, mime::TEXT_EVENT_STREAM.essence_str())
            .body(Full::new(Bytes::from_static(body.as_bytes())))
            .unwrap()
    }

    /// Test body that yields each queued chunk as a separate data frame.
    struct ChunkedBody(VecDeque<Bytes>);

    impl Body for ChunkedBody {
        type Data = Bytes;
        type Error = Infallible;

        fn poll_frame(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
            Poll::Ready(self.get_mut().0.pop_front().map(|chunk| Ok(Frame::data(chunk))))
        }
    }

    /// Builds an event-stream response that delivers `chunks` as separate frames.
    fn make_chunked_response(chunks: &[&'static str]) -> Response<ChunkedBody> {
        let frames = chunks.iter().map(|c| Bytes::from_static(c.as_bytes())).collect();
        Response::builder()
            .header(header::CONTENT_TYPE, mime::TEXT_EVENT_STREAM.essence_str())
            .body(ChunkedBody(frames))
            .unwrap()
    }

    #[test]
    fn rejects_wrong_content_type() {
        let resp = Response::builder()
            .header(header::CONTENT_TYPE, "application/json")
            .body(Full::new(Bytes::new()))
            .unwrap();
        assert!(SseReader::into_sse(resp).is_err());
    }

    #[test]
    fn rejects_missing_content_type() {
        let resp = Response::builder().body(Full::new(Bytes::new())).unwrap();
        assert!(SseReader::into_sse(resp).is_err());
    }

    #[test]
    fn rejects_non_success_status() {
        let resp = Response::builder()
            .status(http::StatusCode::INTERNAL_SERVER_ERROR)
            .header(header::CONTENT_TYPE, mime::TEXT_EVENT_STREAM.essence_str())
            .body(Full::new(Bytes::new()))
            .unwrap();
        assert!(SseReader::into_sse(resp).is_err());
    }

    #[test]
    fn accepts_content_type_with_parameters() {
        let resp = Response::builder()
            .header(header::CONTENT_TYPE, "text/event-stream; charset=utf-8")
            .body(Full::new(Bytes::new()))
            .unwrap();
        assert!(SseReader::into_sse(resp).is_ok());
    }

    #[tokio::test]
    async fn single_data_field() {
        let mut reader = SseReader::into_sse(make_response("data: hello\n\n")).unwrap();
        let event = reader.read().await.unwrap().unwrap();
        assert_eq!(event.data(), Some("hello"));
        assert_eq!(event.event(), "message");
        assert_eq!(event.id(), None);
        assert_eq!(event.retry(), None);
    }

    #[tokio::test]
    async fn single_event_field() {
        let mut reader = SseReader::into_sse(make_response("event: ping\n\n")).unwrap();
        let event = reader.read().await.unwrap().unwrap();
        assert_eq!(event.data(), None);
        assert_eq!(event.event(), "ping");
        assert_eq!(event.id(), None);
        assert_eq!(event.retry(), None);
    }

    #[tokio::test]
    async fn single_id_field() {
        let mut reader = SseReader::into_sse(make_response("id: 42\n\n")).unwrap();
        let event = reader.read().await.unwrap().unwrap();
        assert_eq!(event.data(), None);
        assert_eq!(event.event(), "message");
        assert_eq!(event.id(), Some("42"));
        assert_eq!(event.retry(), None);
    }

    #[tokio::test]
    async fn single_retry_field() {
        let mut reader = SseReader::into_sse(make_response("retry: 3000\n\n")).unwrap();
        let event = reader.read().await.unwrap().unwrap();
        assert_eq!(event.data(), None);
        assert_eq!(event.event(), "message");
        assert_eq!(event.id(), None);
        assert_eq!(event.retry(), Some(Duration::from_millis(3000)));
    }

    #[tokio::test]
    async fn multi_field_event() {
        let mut reader = SseReader::into_sse(make_response(
            "event: ping\ndata: hello\ndata: world\nid: first\nretry: 15000\n: test comment\n\n",
        ))
        .unwrap();
        let event = reader.read().await.unwrap().unwrap();
        assert_eq!(event.event(), "ping");
        assert_eq!(event.data(), Some("hello\nworld"));
        assert_eq!(event.id(), Some("first"));
        assert_eq!(event.retry(), Some(Duration::from_millis(15000)));
    }

    #[tokio::test]
    async fn multiline_data() {
        let mut reader = SseReader::into_sse(make_response(
            "data: 114\ndata: 514\ndata: 1919\ndata: 810\n\n",
        ))
        .unwrap();
        let event = reader.read().await.unwrap().unwrap();
        assert_eq!(event.data(), Some("114\n514\n1919\n810"));
        assert_eq!(event.event(), "message");
        assert_eq!(event.id(), None);
        assert_eq!(event.retry(), None);
    }

    #[tokio::test]
    async fn empty_data_field() {
        let mut reader = SseReader::into_sse(make_response("data:\n\n")).unwrap();
        let event = reader.read().await.unwrap().unwrap();
        assert_eq!(event.data(), Some(""));
        assert_eq!(event.event(), "message");
        assert_eq!(event.id(), None);
        assert_eq!(event.retry(), None);
    }

    #[tokio::test]
    async fn value_without_space_after_colon() {
        let mut reader = SseReader::into_sse(make_response("data:hello\n\n")).unwrap();
        let event = reader.read().await.unwrap().unwrap();
        assert_eq!(event.data(), Some("hello"));
    }

    #[tokio::test]
    async fn only_first_space_after_colon_is_stripped() {
        let mut reader = SseReader::into_sse(make_response("data:  hello\n\n")).unwrap();
        let event = reader.read().await.unwrap().unwrap();
        assert_eq!(event.data(), Some(" hello"));
    }

    #[tokio::test]
    async fn multiple_events() {
        let mut reader = SseReader::into_sse(make_response(
            "event: ping\ndata: -\n\nevent: pong\ndata: -\n\n",
        ))
        .unwrap();

        let e1 = reader.read().await.unwrap().unwrap();
        assert_eq!(e1.data(), Some("-"));
        assert_eq!(e1.event(), "ping");
        assert_eq!(e1.id(), None);
        assert_eq!(e1.retry(), None);

        let e2 = reader.read().await.unwrap().unwrap();
        assert_eq!(e2.data(), Some("-"));
        assert_eq!(e2.event(), "pong");
        assert_eq!(e2.id(), None);
        assert_eq!(e2.retry(), None);

        assert!(reader.read().await.unwrap().is_none());
    }

    #[tokio::test]
    async fn event_split_across_frames() {
        let resp = make_chunked_response(&["data: hel", "lo\n", "\n"]);
        let mut reader = SseReader::into_sse(resp).unwrap();
        let event = reader.read().await.unwrap().unwrap();
        assert_eq!(event.data(), Some("hello"));
        assert!(reader.read().await.unwrap().is_none());
    }

    #[tokio::test]
    async fn returns_none_on_empty_stream() {
        let mut reader = SseReader::into_sse(make_response("")).unwrap();
        assert!(reader.read().await.unwrap().is_none());
    }

    #[tokio::test]
    async fn returns_none_after_last_event() {
        let mut reader = SseReader::into_sse(make_response("data: hello\n\n")).unwrap();
        reader.read().await.unwrap().unwrap();
        assert!(reader.read().await.unwrap().is_none());
    }

    #[tokio::test]
    async fn comments_are_ignored() {
        let mut reader =
            SseReader::into_sse(make_response(": ping\n: pong\n\ndata: hello\n\n")).unwrap();
        let event = reader.read().await.unwrap().unwrap();
        assert_eq!(event.data(), Some("hello"));
        assert_eq!(event.event(), "message");
        assert_eq!(event.id(), None);
        assert_eq!(event.retry(), None);
        assert!(reader.read().await.unwrap().is_none());
    }

    #[tokio::test]
    async fn last_event_id_tracks_across_events() {
        let mut reader = SseReader::into_sse(make_response(
            "id: 1\ndata: a\n\ndata: b\n\nid: 3\ndata: c\n\n",
        ))
        .unwrap();

        reader.read().await.unwrap().unwrap();
        assert_eq!(reader.last_event_id(), "1");

        // Event "b" carries no id, so the previous one is retained.
        reader.read().await.unwrap().unwrap();
        assert_eq!(reader.last_event_id(), "1");

        reader.read().await.unwrap().unwrap();
        assert_eq!(reader.last_event_id(), "3");
    }

    #[tokio::test]
    async fn empty_id_clears_last_event_id() {
        let mut reader =
            SseReader::into_sse(make_response("id: 42\ndata: a\n\nid:\ndata: b\n\n")).unwrap();

        reader.read().await.unwrap().unwrap();
        assert_eq!(reader.last_event_id(), "42");

        // An empty `id:` resets the stream's last id and is not reported on the event.
        let event = reader.read().await.unwrap().unwrap();
        assert_eq!(reader.last_event_id(), "");
        assert_eq!(event.id(), None);
    }

    #[tokio::test]
    async fn id_containing_nul_is_ignored() {
        let mut reader = SseReader::into_sse(make_response("id: a\0b\ndata: x\n\n")).unwrap();
        let event = reader.read().await.unwrap().unwrap();
        assert_eq!(event.data(), Some("x"));
        assert_eq!(event.id(), None);
        assert_eq!(reader.last_event_id(), "");
    }

    #[tokio::test]
    async fn retry_invalid_is_ignored() {
        let mut reader = SseReader::into_sse(make_response("retry: abc\ndata: hello\n\n")).unwrap();
        let event = reader.read().await.unwrap().unwrap();
        assert_eq!(event.data(), Some("hello"));
        assert_eq!(event.retry(), None);
    }

    #[tokio::test]
    async fn retry_with_suffix_is_ignored() {
        let mut reader =
            SseReader::into_sse(make_response("retry: 1000abc\ndata: hello\n\n")).unwrap();
        let event = reader.read().await.unwrap().unwrap();
        assert_eq!(event.data(), Some("hello"));
        assert_eq!(event.retry(), None);
    }

    #[tokio::test]
    async fn crlf_line_endings() {
        let mut reader =
            SseReader::into_sse(make_response("data: hello\r\ndata: world\r\n\r\n")).unwrap();
        let event = reader.read().await.unwrap().unwrap();
        assert_eq!(event.data(), Some("hello\nworld"));
    }

    #[tokio::test]
    async fn bare_cr_line_endings() {
        let mut reader =
            SseReader::into_sse(make_response("data: hello\rdata: world\r\r")).unwrap();
        let event = reader.read().await.unwrap().unwrap();
        assert_eq!(event.data(), Some("hello\nworld"));
    }

    #[tokio::test]
    async fn bom_stripped_on_first_line() {
        let mut body = vec![0xEF, 0xBB, 0xBF];
        body.extend_from_slice(b"data: hello\n\n");
        let resp = Response::builder()
            .header(header::CONTENT_TYPE, mime::TEXT_EVENT_STREAM.essence_str())
            .body(Full::new(Bytes::from(body)))
            .unwrap();
        let mut reader = SseReader::into_sse(resp).unwrap();
        let event = reader.read().await.unwrap().unwrap();
        assert_eq!(event.data(), Some("hello"));
    }

    #[tokio::test]
    async fn unknown_field_is_ignored() {
        let mut reader =
            SseReader::into_sse(make_response("unknown: value\ndata: hello\n\n")).unwrap();
        let event = reader.read().await.unwrap().unwrap();
        assert_eq!(event.data(), Some("hello"));
    }

    #[tokio::test]
    async fn field_with_no_colon_is_ignored() {
        let mut reader =
            SseReader::into_sse(make_response("unknownfield\ndata: hello\n\n")).unwrap();
        let event = reader.read().await.unwrap().unwrap();
        assert_eq!(event.data(), Some("hello"));
    }

    #[tokio::test]
    async fn event_without_trailing_blank_line_is_flushed() {
        let mut reader = SseReader::into_sse(make_response("data: hello")).unwrap();
        let event = reader.read().await.unwrap().unwrap();
        assert_eq!(event.data(), Some("hello"));
    }
}