1pub mod error;
30
31use std::{pin::Pin, time::Duration};
32
33use async_stream::try_stream;
34use reqwest::{
35 Response, StatusCode,
36 header::{CONTENT_TYPE, HeaderValue},
37};
38use tokio::io::AsyncBufReadExt;
39use tokio_stream::{Stream, StreamExt};
40use tokio_util::io::StreamReader;
41
42use crate::error::{EventError, EventSourceError};
43
/// A pinned, boxed stream of server-sent events.
///
/// Every item is either a successfully parsed [`Event`] or the [`EventError`]
/// that was raised while reading the underlying stream.
pub type ServerSentEvents = Pin<Box<dyn Stream<Item = Result<Event, EventError>>>>;
46
/// The media type, as raw bytes, that a `Content-Type` header value is compared
/// against (ignoring ASCII case and any parameters) to recognise an event stream.
pub static MIME_EVENT_STREAM: &[u8] = b"text/event-stream";
48
49fn is_event_stream(value: &HeaderValue) -> bool {
51 value
52 .as_bytes()
53 .split(|&b| b == b';')
54 .next()
55 .unwrap_or(b"")
56 .trim_ascii()
57 .eq_ignore_ascii_case(MIME_EVENT_STREAM)
58}
59
60struct EventBuffer {
66 event_type: String,
67 data: String,
68 last_event_id: Option<String>,
69 retry: Option<Duration>,
70}
71
72impl EventBuffer {
73 #[allow(clippy::new_without_default)]
75 fn new() -> Self {
76 Self {
77 event_type: String::new(),
78 data: String::new(),
79 last_event_id: None,
80 retry: None,
81 }
82 }
83
84 fn produce_event(&mut self) -> Option<Event> {
88 let event = if self.data.is_empty() {
89 None
90 } else {
91 Some(Event {
92 event_type: if self.event_type.is_empty() {
93 "message".to_string()
94 } else {
95 self.event_type.clone()
96 },
97 data: self.data.clone(),
98 last_event_id: self.last_event_id.clone(),
99 retry: self.retry,
100 })
101 };
102
103 self.event_type.clear();
104 self.data.clear();
105
106 event
107 }
108
109 fn set_event_type(&mut self, event_type: &str) {
111 self.event_type.clear();
112 self.event_type.push_str(event_type);
113 }
114
115 fn push_data(&mut self, data: &str) {
117 if !self.data.is_empty() {
118 self.data.push('\n');
119 }
120 self.data.push_str(data);
121 }
122
123 fn set_id(&mut self, id: &str) {
124 self.last_event_id = Some(id.to_string());
125 }
126
127 fn set_retry(&mut self, retry: Duration) {
128 self.retry = Some(retry);
129 }
130}
131
/// Splits one line of an event stream into its field name and value.
///
/// The line is cut at the first `:`; later colons stay in the value. A single
/// leading space of the value is dropped. A line without any colon is returned
/// whole as the field name, paired with an empty value.
fn parse_line(line: &str) -> (&str, &str) {
    match line.split_once(':') {
        Some((name, rest)) => (name, rest.strip_prefix(' ').unwrap_or(rest)),
        None => (line, ""),
    }
}
138
/// A single event parsed from a server-sent event stream.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct Event {
    /// Value of the event's `event` field, or `"message"` when none was sent.
    pub event_type: String,
    /// Payload built from the event's `data` lines, joined with `\n`.
    pub data: String,
    /// Value of the most recent `id` field seen on the stream so far, if any.
    /// It is not reset between events.
    pub last_event_id: Option<String>,
    /// Delay taken from the most recent `retry` field seen on the stream so far
    /// (parsed as a number of milliseconds), if any. It is not reset between events.
    pub retry: Option<Duration>,
}
151
/// Implemented by values that can be turned into a stream of server-sent events.
pub trait EventSource {
    /// Consumes `self` and resolves to a [`ServerSentEvents`] stream, or to an
    /// [`EventSourceError`] when the value cannot be read as an event stream.
    ///
    /// The returned future is required to be `Send`.
    fn events(self) -> impl Future<Output = Result<ServerSentEvents, EventSourceError>> + Send;
}
166
167impl EventSource for Response {
168 async fn events(self) -> Result<ServerSentEvents, EventSourceError> {
169 let status = self.status();
170 if status != StatusCode::OK {
171 return Err(EventSourceError::BadStatus(status));
172 }
173 match self.headers().get(CONTENT_TYPE) {
174 Some(content_type) => {
175 if !is_event_stream(content_type) {
176 return Err(EventSourceError::BadContentType(Some(
177 content_type.to_owned(),
178 )));
179 }
180 }
181 None => return Err(EventSourceError::BadContentType(None)),
182 }
183
184 let mut stream = StreamReader::new(
185 self.bytes_stream()
186 .map(|result| result.map_err(std::io::Error::other)),
187 );
188
189 let mut line_buffer = String::new();
190 let mut event_buffer = EventBuffer::new();
191
192 let stream = Box::pin(try_stream! {
193 loop {
194 line_buffer.clear();
195 let count = stream.read_line(&mut line_buffer).await.map_err(EventError::IoError)?;
196 if count == 0 {
197 break;
198 }
199 let line = if let Some(line) = line_buffer.strip_suffix('\n') {
200 line
201 } else {
202 &line_buffer
203 };
204
205 if line.is_empty() {
207 if let Some(event) = event_buffer.produce_event() {
208 yield event;
209 }
210 continue;
211 }
212
213 let (field, value) = parse_line(line);
214
215 match field {
216 "event" => {
217 event_buffer.set_event_type(value);
218 }
219 "data" => {
220 event_buffer.push_data(value);
221 }
222 "id" => {
223 event_buffer.set_id(value);
224 }
225 "retry" => {
226 if let Ok(millis) = value.parse() {
227 event_buffer.set_retry(Duration::from_millis(millis));
228 }
229 }
230 _ => {}
231 }
232 }
233 });
234
235 Ok(stream)
236 }
237}
238
#[cfg(test)]
mod tests {
    use super::*;

    /// `parse_line` splits at the first colon and drops one leading space.
    #[test]
    fn parse_line_properly() {
        let cases = [
            ("event: message", ("event", "message")),
            ("non-standard field", ("non-standard field", "")),
            ("data:data with : inside", ("data", "data with : inside")),
        ];

        for (line, (expected_field, expected_value)) in cases {
            let (field, value) = parse_line(line);
            assert_eq!(field, expected_field);
            assert_eq!(value, expected_value);
        }
    }

    /// Parameters, surrounding whitespace and letter case must not matter.
    #[test]
    fn is_event_stream_accept_valid_values() {
        let accepted = [
            "text/event-stream",
            "text/event-stream; charset=utf-8",
            " TEXT/event-stream ; charset=utf-8",
        ];

        for header in accepted {
            assert!(is_event_stream(&HeaderValue::from_static(header)));
        }
    }

    /// Other media types, including look-alikes, are rejected.
    #[test]
    fn is_event_stream_reject_invalid_values() {
        let rejected = ["plain/text", "text/event-but-not-realy"];

        for header in rejected {
            assert!(!is_event_stream(&HeaderValue::from_static(header)));
        }
    }
}
278}